From f07fdf96b4ccc0a9c6e87fe2562e9359be85d3af Mon Sep 17 00:00:00 2001 From: Paul McDaniel Date: Fri, 15 Nov 2019 10:54:44 -0800 Subject: [PATCH 1/3] model moved over. everything builds clean. step ! --- cmake/winml.cmake | 26 ++-- winml/lib/Api.Core/ModelInfo.cpp | 116 --------------- winml/lib/Api.Core/WinMLAdapter.cpp | 195 +++++++++++++++++++++++++- winml/lib/Api.Core/inc/ModelInfo.h | 20 --- winml/lib/Api.Core/inc/WinMLAdapter.h | 23 ++- winml/lib/Api/LearningModel.cpp | 74 ++++------ winml/lib/Api/LearningModel.h | 12 +- 7 files changed, 251 insertions(+), 215 deletions(-) diff --git a/cmake/winml.cmake b/cmake/winml.cmake index 94e7db025bbd7..3fa0f4e72afcd 100644 --- a/cmake/winml.cmake +++ b/cmake/winml.cmake @@ -471,24 +471,24 @@ if("${CMAKE_BUILD_TYPE}" STREQUAL "Debug") endif("${CMAKE_BUILD_TYPE}" STREQUAL "Debug") # Link libraries -target_link_libraries(winml_dll PRIVATE libprotobuf) -target_link_libraries(winml_dll PRIVATE onnx) -target_link_libraries(winml_dll PRIVATE onnxruntime_common) -target_link_libraries(winml_dll PRIVATE onnxruntime_graph) -target_link_libraries(winml_dll PRIVATE onnxruntime_framework) -target_link_libraries(winml_dll PRIVATE onnxruntime_mlas) -target_link_libraries(winml_dll PRIVATE onnxruntime_optimizer) -target_link_libraries(winml_dll PRIVATE onnxruntime_providers) -target_link_libraries(winml_dll PRIVATE onnxruntime_providers_dml) -target_link_libraries(winml_dll PRIVATE onnxruntime_session) -target_link_libraries(winml_dll PRIVATE onnxruntime_util) -target_link_libraries(winml_dll PRIVATE onnx_proto) +#target_link_libraries(winml_dll PRIVATE libprotobuf) +#target_link_libraries(winml_dll PRIVATE onnx) +#target_link_libraries(winml_dll PRIVATE onnxruntime_common) +#target_link_libraries(winml_dll PRIVATE onnxruntime_graph) +#target_link_libraries(winml_dll PRIVATE onnxruntime_framework) +#target_link_libraries(winml_dll PRIVATE onnxruntime_mlas) +#target_link_libraries(winml_dll PRIVATE onnxruntime_optimizer) +#target_link_libraries(winml_dll PRIVATE onnxruntime_providers) +#target_link_libraries(winml_dll PRIVATE onnxruntime_providers_dml) +#target_link_libraries(winml_dll PRIVATE onnxruntime_session) +#target_link_libraries(winml_dll PRIVATE onnxruntime_util) +#target_link_libraries(winml_dll PRIVATE onnx_proto) target_link_libraries(winml_dll PRIVATE onnxruntime) target_link_libraries(winml_dll PRIVATE re2) target_link_libraries(winml_dll PRIVATE wil) target_link_libraries(winml_dll PRIVATE windowsapp.lib) target_link_libraries(winml_dll PRIVATE winml_lib_api) -target_link_libraries(winml_dll PRIVATE winml_lib_core) +#target_link_libraries(winml_dll PRIVATE winml_lib_core) target_link_libraries(winml_dll PRIVATE winml_lib_image) target_link_libraries(winml_dll PRIVATE winml_lib_telemetry) target_link_libraries(winml_dll PRIVATE ${DBGHELP}) diff --git a/winml/lib/Api.Core/ModelInfo.cpp b/winml/lib/Api.Core/ModelInfo.cpp index 5d34cb50623b3..46e9bc29bfa3d 100644 --- a/winml/lib/Api.Core/ModelInfo.cpp +++ b/winml/lib/Api.Core/ModelInfo.cpp @@ -10,121 +10,5 @@ using namespace Windows::AI::MachineLearning; -static std::vector -GetAllNodeOutputs(const onnx::ModelProto& model_proto) { - std::vector nodes_outputs; - auto& graph = model_proto.graph(); - auto& nodes = graph.node(); - for (auto& node : nodes) { - for (auto& node_output : node.output()) { - nodes_outputs.push_back(node_output.c_str()); - } - } - return nodes_outputs; -} - -static std::vector -GetInitializers(const onnx::ModelProto& model_proto) { - std::vector initializers; - auto& graph = model_proto.graph(); - auto& graph_initializers = graph.initializer(); - for (auto& initializer : graph_initializers) { - initializers.push_back(initializer.name().c_str()); - } - return initializers; -} - -static std::vector -GetInputsWithoutInitializers(const onnx::ModelProto& model_proto) { - auto initializers = GetInitializers(model_proto); - - std::vector inputs_without_initializers; - auto& graph = model_proto.graph(); - auto& inputs = graph.input(); - for (auto& input : inputs) { - if (input.has_name() && input.has_type()) { - auto found_it = std::find_if( - std::begin(initializers), - std::end(initializers), - [&](auto& initializer) { - return std::strcmp(initializer, input.name().c_str()) == 0; - }); - - auto is_initializer = found_it != std::end(initializers); - if (!is_initializer) { - inputs_without_initializers.push_back(&input); - } - } - } - return inputs_without_initializers; -} - -static std::vector -GetOutputs(const onnx::ModelProto& model_proto) { - std::vector outputs_with_name; - auto& graph = model_proto.graph(); - auto& outputs = graph.output(); - for (auto& output : outputs) { - if (output.has_name() && output.has_type()) { - outputs_with_name.push_back(&output); - } - } - return outputs_with_name; -} - -ModelInfo::ModelInfo( - const onnx::ModelProto* model_proto) { - Initialize(model_proto); -} - -void ModelInfo::Initialize( - const onnx::ModelProto* model_proto) { - // metadata - for (auto& prop : model_proto->metadata_props()) { - model_metadata_[prop.key()] = prop.value(); - } - - WinML::FeatureDescriptorFactory builder(model_metadata_); - - // Create inputs - auto inputs = GetInputsWithoutInitializers(*model_proto); - input_features_ = builder.CreateDescriptorsFromValueInfoProtos(inputs); - - // Create outputs - auto outputs = ::GetOutputs(*model_proto); - output_features_ = builder.CreateDescriptorsFromValueInfoProtos(outputs); - - // author - auto has_producer_name = model_proto->has_producer_name(); - author_ = has_producer_name - ? model_proto->producer_name() - : ""; - - // domain - auto has_domain = model_proto->has_domain(); - domain_ = has_domain - ? model_proto->domain() - : ""; - - // name - auto has_graph = model_proto->has_graph(); - auto graph_has_name = model_proto->graph().has_name(); - auto is_name_available = has_graph && graph_has_name; - name_ = is_name_available - ? model_proto->graph().name() - : ""; - - // description - auto has_description = model_proto->has_doc_string(); - description_ = has_description - ? model_proto->doc_string() - : ""; - - // version - auto has_version = model_proto->has_model_version(); - version_ = has_version - ? model_proto->model_version() - : 0; -} diff --git a/winml/lib/Api.Core/WinMLAdapter.cpp b/winml/lib/Api.Core/WinMLAdapter.cpp index 105705b7a917a..b40318f381537 100644 --- a/winml/lib/Api.Core/WinMLAdapter.cpp +++ b/winml/lib/Api.Core/WinMLAdapter.cpp @@ -4,10 +4,12 @@ #include "pch.h" #include "inc/WinMLAdapter.h" #include "inc/CustomRegistryHelper.h" +#include "PheonixSingleton.h" #include "inc/LotusEnvironment.h" #include "inc/AbiCustomRegistryImpl.h" #include "core/providers/dml/DmlExecutionProvider/inc/DmlExecutionProvider.h" #include "core/providers/dml/GraphTransformers/GraphTransformerHelpers.h" +#include "core/providers/dml/OperatorAuthorHelper/SchemaInferenceOverrider.h" #include "LearningModelDevice.h" #include "TensorFeatureDescriptor.h" @@ -25,6 +27,8 @@ #include "ZeroCopyInputStreamWrapper.h" #include "google/protobuf/io/zero_copy_stream_impl.h" +#include "FeatureDescriptorFactory.h" + using namespace winrt::Windows::AI::MachineLearning; @@ -111,7 +115,7 @@ class AbiSafeOrtValue : public Microsoft::WRL::RuntimeClass < *tensor = tensor_outer.Detach(); return S_OK; } -}; +}; // class AbiSafeOrtValue class ModelProto : public Microsoft::WRL::RuntimeClass< Microsoft::WRL::RuntimeClassFlags, @@ -128,12 +132,178 @@ class ModelProto : public Microsoft::WRL::RuntimeClass< private: std::shared_ptr model_proto_; -}; +}; // class ModelProto + + +class ModelInfo : public Microsoft::WRL::RuntimeClass< + Microsoft::WRL::RuntimeClassFlags, + IModelInfo> { + +private: + std::string author_; + std::string name_; + std::string domain_; + std::string description_; + int64_t version_; + std::unordered_map model_metadata_; + wfc::IVector input_features_; + wfc::IVector output_features_; + +public: + + ModelInfo(const onnx::ModelProto* model_proto) { + Initialize(model_proto); + } + + std::string STDMETHODCALLTYPE author() override { + return author_; + } + std::string STDMETHODCALLTYPE name() override { + return name_; + } + std::string STDMETHODCALLTYPE domain() override { + return domain_; + } + std::string STDMETHODCALLTYPE description() override { + return description_; + } + int64_t STDMETHODCALLTYPE version() override { + return version_; + } + std::unordered_map STDMETHODCALLTYPE model_metadata() override { + return model_metadata_; + } + wfc::IVector STDMETHODCALLTYPE input_features() override { + return input_features_; + } + wfc::IVector STDMETHODCALLTYPE output_features() override { + return output_features_; + } + + static std::vector + GetAllNodeOutputs(const onnx::ModelProto& model_proto) { + std::vector nodes_outputs; + auto& graph = model_proto.graph(); + auto& nodes = graph.node(); + for (auto& node : nodes) { + for (auto& node_output : node.output()) { + nodes_outputs.push_back(node_output.c_str()); + } + } + return nodes_outputs; + } + + static std::vector + GetInitializers(const onnx::ModelProto& model_proto) { + std::vector initializers; + auto& graph = model_proto.graph(); + auto& graph_initializers = graph.initializer(); + for (auto& initializer : graph_initializers) { + initializers.push_back(initializer.name().c_str()); + } + return initializers; + } + + static std::vector + GetInputsWithoutInitializers(const onnx::ModelProto& model_proto) { + auto initializers = GetInitializers(model_proto); + + std::vector inputs_without_initializers; + auto& graph = model_proto.graph(); + auto& inputs = graph.input(); + for (auto& input : inputs) { + if (input.has_name() && input.has_type()) { + auto found_it = std::find_if( + std::begin(initializers), + std::end(initializers), + [&](auto& initializer) { + return std::strcmp(initializer, input.name().c_str()) == 0; + }); + + auto is_initializer = found_it != std::end(initializers); + if (!is_initializer) { + inputs_without_initializers.push_back(&input); + } + } + } + return inputs_without_initializers; + } + + static + std::vector GetOutputs(const onnx::ModelProto& model_proto) { + std::vector outputs_with_name; + auto& graph = model_proto.graph(); + auto& outputs = graph.output(); + for (auto& output : outputs) { + if (output.has_name() && output.has_type()) { + outputs_with_name.push_back(&output); + } + } + return outputs_with_name; + } + +private: + void Initialize(const onnx::ModelProto* model_proto) { + // metadata + for (auto& prop : model_proto->metadata_props()) { + model_metadata_[prop.key()] = prop.value(); + } + + WinML::FeatureDescriptorFactory builder(model_metadata_); + + // Create inputs + auto inputs = GetInputsWithoutInitializers(*model_proto); + input_features_ = builder.CreateDescriptorsFromValueInfoProtos(inputs); + + // Create outputs + auto outputs = GetOutputs(*model_proto); + output_features_ = builder.CreateDescriptorsFromValueInfoProtos(outputs); + + // author + auto has_producer_name = model_proto->has_producer_name(); + author_ = has_producer_name + ? model_proto->producer_name() + : ""; + + // domain + auto has_domain = model_proto->has_domain(); + domain_ = has_domain + ? model_proto->domain() + : ""; + + // name + auto has_graph = model_proto->has_graph(); + auto graph_has_name = model_proto->graph().has_name(); + auto is_name_available = has_graph && graph_has_name; + name_ = is_name_available + ? model_proto->graph().name() + : ""; + + // description + auto has_description = model_proto->has_doc_string(); + description_ = has_description + ? model_proto->doc_string() + : ""; + + // version + auto has_version = model_proto->has_model_version(); + version_ = has_version + ? model_proto->model_version() + : 0; + } +}; // class ModelInfo class WinMLAdapter : public Microsoft::WRL::RuntimeClass< - Microsoft::WRL::RuntimeClassFlags, + Microsoft::WRL::RuntimeClassFlags, IWinMLAdapter> { +private: + std::shared_ptr lotus_environment_; + public: + WinMLAdapter() : lotus_environment_(PheonixSingleton()) { + + } + // factory methods for creating an ort model from a path HRESULT STDMETHODCALLTYPE CreateModelProto( const char* path, @@ -188,6 +358,12 @@ class WinMLAdapter : public Microsoft::WRL::RuntimeClass< return model_proto_outer.CopyTo(__uuidof(IModelProto), (void**)model_proto); } + HRESULT STDMETHODCALLTYPE CreateModelInfo(IModelProto * model_proto, IModelInfo ** model_info) override { + auto model_info_outer = wil::MakeOrThrow(model_proto->get()); + return model_info_outer.CopyTo(__uuidof(IModelInfo), (void**)model_info); + } + + void STDMETHODCALLTYPE EnableDebugOutput() override { WinML::CWinMLLogSink::EnableDebugOutput(); } @@ -516,6 +692,19 @@ class WinMLAdapter : public Microsoft::WRL::RuntimeClass< return S_OK; } + // Override select shape inference functions which are incomplete in ONNX with versions that are complete, + // and are also used in DML kernel registrations. Doing this avoids kernel and shader creation being + // deferred until first evaluation. It also prevents a situation where inference functions in externally + // registered schema are reachable only after upstream schema have been revised in a later OS release, + // which would be a compatibility risk. + HRESULT STDMETHODCALLTYPE OverrideSchemaInferenceFunctions() override { + static std::once_flag schema_override_once_flag; + std::call_once(schema_override_once_flag, []() { + SchemaInferenceOverrider::OverrideSchemaInferenceFunctions(); + }); + return S_OK; + } + }; diff --git a/winml/lib/Api.Core/inc/ModelInfo.h b/winml/lib/Api.Core/inc/ModelInfo.h index 95dd8a2988b6b..3546aee5416e5 100644 --- a/winml/lib/Api.Core/inc/ModelInfo.h +++ b/winml/lib/Api.Core/inc/ModelInfo.h @@ -3,27 +3,7 @@ #pragma once -#include "WinMLAdapter.h" - namespace Windows::AI::MachineLearning { -class ModelInfo { - public: - ModelInfo(const onnx::ModelProto* model_proto); - - public: - // model metadata - std::string author_; - std::string name_; - std::string domain_; - std::string description_; - int64_t version_; - std::unordered_map model_metadata_; - wfc::IVector input_features_; - wfc::IVector output_features_; - - private: - void Initialize(const onnx::ModelProto* model_proto); -}; } // namespace Windows::AI::MachineLearning \ No newline at end of file diff --git a/winml/lib/Api.Core/inc/WinMLAdapter.h b/winml/lib/Api.Core/inc/WinMLAdapter.h index 2ef2a3b9beac8..c31ffbd473dd9 100644 --- a/winml/lib/Api.Core/inc/WinMLAdapter.h +++ b/winml/lib/Api.Core/inc/WinMLAdapter.h @@ -4,9 +4,22 @@ #pragma once #include "IOrtSessionBuilder.h" +#include "ModelInfo.h" namespace Windows::AI::MachineLearning::Adapter { +MIDL_INTERFACE("eaae30b5-7381-432d-9730-322136b02371") IModelInfo : IUnknown{ + // model metadata + virtual std::string STDMETHODCALLTYPE author() = 0; + virtual std::string STDMETHODCALLTYPE name() = 0; + virtual std::string STDMETHODCALLTYPE domain() = 0; + virtual std::string STDMETHODCALLTYPE description() = 0; + virtual int64_t STDMETHODCALLTYPE version() = 0; + virtual std::unordered_map STDMETHODCALLTYPE model_metadata() = 0; + virtual wfc::IVector STDMETHODCALLTYPE input_features() = 0; + virtual wfc::IVector STDMETHODCALLTYPE output_features() = 0; +}; + MIDL_INTERFACE("eaae30b5-7381-432d-9730-322136b02371") ITensor : IUnknown{ // these all return weak pointers virtual const onnxruntime::Tensor& STDMETHODCALLTYPE get() = 0; @@ -92,14 +105,11 @@ MIDL_INTERFACE("b19385e7-d9af-441a-ba7f-3993c7b1c9db") IWinMLAdapter : IUnknown ID3D12CommandQueue* queue, IOrtSessionBuilder** session_builder) = 0; - // factory methods for creating an ort model from a path + // factory methods for creating model protos virtual HRESULT STDMETHODCALLTYPE CreateModelProto(const char* path, IModelProto** model_proto) = 0; - - // factory methods for creating an ort model from a stream virtual HRESULT STDMETHODCALLTYPE CreateModelProto(ABI::Windows::Storage::Streams::IRandomAccessStreamReference* stream_reference, IModelProto** model_proto) = 0; - - // factory methods for creating an ort model from a model_proto virtual HRESULT STDMETHODCALLTYPE CreateModelProto(IModelProto * model_proto_in, IModelProto** model_proto) = 0; + virtual HRESULT STDMETHODCALLTYPE CreateModelInfo(IModelProto * model_proto, IModelInfo ** model_info) = 0; // Data types virtual onnxruntime::MLDataType STDMETHODCALLTYPE GetTensorType() = 0; @@ -142,7 +152,8 @@ MIDL_INTERFACE("b19385e7-d9af-441a-ba7f-3993c7b1c9db") IWinMLAdapter : IUnknown onnxruntime::MLDataType data_type, IOrtValue ** ort_value) = 0; - + // schema overrides (dml does this for us) + virtual HRESULT STDMETHODCALLTYPE OverrideSchemaInferenceFunctions() = 0; }; extern "C" diff --git a/winml/lib/Api/LearningModel.cpp b/winml/lib/Api/LearningModel.cpp index 046bad17c029f..78f821e26e8be 100644 --- a/winml/lib/Api/LearningModel.cpp +++ b/winml/lib/Api/LearningModel.cpp @@ -5,10 +5,8 @@ #include "LearningModel.h" -#include "core/providers/dml/OperatorAuthorHelper/SchemaInferenceOverrider.h" #include "core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.h" #include "ModelInfo.h" -#include "PheonixSingleton.h" #include "TelemetryEvent.h" #include "LotusEnvironment.h" @@ -21,21 +19,19 @@ namespace winrt::Windows::AI::MachineLearning::implementation { LearningModel::LearningModel( const hstring& path, const winml::ILearningModelOperatorProvider op_provider) try : LearningModel(WinML::Strings::UTF8FromHString(path), - op_provider) {} + op_provider) { +} WINML_CATCH_ALL LearningModel::LearningModel( const std::string& path, - const winml::ILearningModelOperatorProvider operator_provider) try : lotus_environment_(PheonixSingleton()), - operator_provider_(operator_provider) { + const winml::ILearningModelOperatorProvider operator_provider) try : operator_provider_(operator_provider) { _winmlt::PerformanceTelemetryEvent kLoadModel_event( WinMLRuntimePerf::kLoadModel); - OverrideShapeInferenceMethods(); - - com_ptr<_winmla::IWinMLAdapter> adapter; - WINML_THROW_IF_FAILED(OrtGetWinMLAdapter(adapter.put())); - WINML_THROW_IF_FAILED(adapter->CreateModelProto(path.c_str(), model_proto_.put())); + WINML_THROW_IF_FAILED(OrtGetWinMLAdapter(adapter_.put())); + WINML_THROW_IF_FAILED(adapter_->OverrideSchemaInferenceFunctions()); + WINML_THROW_IF_FAILED(adapter_->CreateModelProto(path.c_str(), model_proto_.put())); Initialize(); @@ -45,19 +41,16 @@ WINML_CATCH_ALL LearningModel::LearningModel( const wss::IRandomAccessStreamReference stream, - const winml::ILearningModelOperatorProvider operator_provider) try : lotus_environment_(PheonixSingleton()), - operator_provider_(operator_provider) { + const winml::ILearningModelOperatorProvider operator_provider) try : operator_provider_(operator_provider) { _winmlt::PerformanceTelemetryEvent kLoadModel_event( WinMLRuntimePerf::kLoadModel); - OverrideShapeInferenceMethods(); - - com_ptr<_winmla::IWinMLAdapter> adapter; - WINML_THROW_IF_FAILED(OrtGetWinMLAdapter(adapter.put())); - WINML_THROW_IF_FAILED(adapter->CreateModelProto( - static_cast(winrt::get_abi(stream)), + WINML_THROW_IF_FAILED(OrtGetWinMLAdapter(adapter_.put())); + WINML_THROW_IF_FAILED(adapter_->OverrideSchemaInferenceFunctions()); + WINML_THROW_IF_FAILED(adapter_->CreateModelProto( + static_cast(winrt::get_abi(stream)), model_proto_.put())); - + Initialize(); LogCreationEvent(true); @@ -65,8 +58,7 @@ LearningModel::LearningModel( WINML_CATCH_ALL void LearningModel::Initialize() { - model_info_ = std::make_unique( - model_proto_.get()->get()); + WINML_THROW_IF_FAILED(adapter_->CreateModelInfo(model_proto_.get(), model_info_.put())); } void LearningModel::LogCreationEvent(bool fromStream) { @@ -80,13 +72,13 @@ void LearningModel::LogCreationEvent(bool fromStream) { } telemetry_helper.LogModelCreation( fromStream, - model_info_->author_, - model_info_->name_, - model_info_->domain_, - model_info_->description_, - model_info_->version_, + model_info_->author(), + model_info_->name(), + model_info_->domain(), + model_info_->description(), + model_info_->version(), use_fp16, - model_info_->model_metadata_); + model_info_->model_metadata()); } void LearningModel::ModelUseFP16( @@ -119,41 +111,41 @@ void LearningModel::ModelUseFP16( hstring LearningModel::Author() try { - return WinML::Strings::HStringFromUTF8(model_info_->author_); + return WinML::Strings::HStringFromUTF8(model_info_->author()); } WINML_CATCH_ALL hstring LearningModel::Name() try { return WinML::Strings::HStringFromUTF8( - model_info_->name_); + model_info_->name()); } WINML_CATCH_ALL hstring LearningModel::Domain() try { return WinML::Strings::HStringFromUTF8( - model_info_->domain_); + model_info_->domain()); } WINML_CATCH_ALL hstring LearningModel::Description() try { return WinML::Strings::HStringFromUTF8( - model_info_->description_); + model_info_->description()); } WINML_CATCH_ALL int64_t LearningModel::Version() try { - return model_info_->version_; + return model_info_->version(); } WINML_CATCH_ALL wfc::IMapView LearningModel::Metadata() try { std::unordered_map map_copy; - for (auto& pair : model_info_->model_metadata_) { + for (auto& pair : model_info_->model_metadata()) { auto key = WinML::Strings::HStringFromUTF8(pair.first); auto value = WinML::Strings::HStringFromUTF8(pair.second); map_copy.emplace(std::move(key), std::move(value)); @@ -183,13 +175,13 @@ LearningModel::GetOperatorRegistry() { wfc::IVectorView LearningModel::InputFeatures() try { - return model_info_->input_features_.GetView(); + return model_info_->input_features().GetView(); } WINML_CATCH_ALL wfc::IVectorView LearningModel::OutputFeatures() try { - return model_info_->output_features_.GetView(); + return model_info_->output_features().GetView(); } WINML_CATCH_ALL @@ -287,18 +279,6 @@ LearningModel::CopyModelProto() { return model_proto.detach(); } -static std::once_flag g_schema_override_once_flag; - -// Override select shape inference functions which are incomplete in ONNX with versions that are complete, -// and are also used in DML kernel registrations. Doing this avoids kernel and shader creation being -// deferred until first evaluation. It also prevents a situation where inference functions in externally -// registered schema are reachable only after upstream schema have been revised in a later OS release, -// which would be a compatibility risk. -void LearningModel::OverrideShapeInferenceMethods() { - std::call_once(g_schema_override_once_flag, []() { - SchemaInferenceOverrider::OverrideSchemaInferenceFunctions(); - }); -} } // namespace winrt::Windows::AI::MachineLearning::implementation namespace winrt::Windows::AI::MachineLearning::factory_implementation { diff --git a/winml/lib/Api/LearningModel.h b/winml/lib/Api/LearningModel.h index ef08359b6c730..c30057ea5c043 100644 --- a/winml/lib/Api/LearningModel.h +++ b/winml/lib/Api/LearningModel.h @@ -6,11 +6,6 @@ #include "LearningModel.g.h" #include "WinMLAdapter.h" - namespace Windows::AI::MachineLearning { - class LotusEnvironment; - class ModelInfo; -} // namespace Windows::AI::MachineLearning - namespace winrt::Windows::AI::MachineLearning::implementation { struct LearningModel : LearningModelT { @@ -121,13 +116,10 @@ struct LearningModel : LearningModelT { winml::ILearningModelFeatureDescriptor descriptor, bool& use_fp16); - void - OverrideShapeInferenceMethods(); - private: - std::shared_ptr lotus_environment_; + com_ptr<_winmla::IWinMLAdapter> adapter_; com_ptr<_winmla::IModelProto> model_proto_; - std::unique_ptr model_info_; + com_ptr<_winmla::IModelInfo> model_info_; ILearningModelOperatorProvider operator_provider_; }; From f32bbd5cb79b0d11d73ba640baff56c8ea76603d Mon Sep 17 00:00:00 2001 From: Paul McDaniel Date: Fri, 15 Nov 2019 13:15:57 -0800 Subject: [PATCH 2/3] weak ref comment --- winml/lib/Api.Core/inc/WinMLAdapter.h | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/winml/lib/Api.Core/inc/WinMLAdapter.h b/winml/lib/Api.Core/inc/WinMLAdapter.h index c31ffbd473dd9..3c74a3726ba0c 100644 --- a/winml/lib/Api.Core/inc/WinMLAdapter.h +++ b/winml/lib/Api.Core/inc/WinMLAdapter.h @@ -38,11 +38,12 @@ MIDL_INTERFACE("72aa5eee-100c-4146-9008-4643d3b8af23") IOrtValue : IUnknown{ virtual OrtValue& STDMETHODCALLTYPE get() = 0; virtual onnxruntime::MLDataType STDMETHODCALLTYPE Type() = 0; virtual bool STDMETHODCALLTYPE IsTensor() = 0; -// end + // end virtual HRESULT STDMETHODCALLTYPE GetTensor(ITensor ** tensor) = 0; }; MIDL_INTERFACE("438e7719-554a-4058-84d9-eb6226c34887") IIOBinding : IUnknown{ + // this returns a weak ref virtual onnxruntime::IOBinding* STDMETHODCALLTYPE get() = 0; virtual HRESULT STDMETHODCALLTYPE BindInput(const std::string& name, IOrtValue * ml_value) = 0; virtual HRESULT STDMETHODCALLTYPE BindOutput(const std::string& name, IOrtValue * ml_value) = 0; From 7f9a7f5abef28851a5f1dfb8eab920e6f4eb8a91 Mon Sep 17 00:00:00 2001 From: Paul McDaniel Date: Fri, 15 Nov 2019 16:47:33 -0800 Subject: [PATCH 3/3] added a wrapper for RoGetActivationFactory to hook back into winml for creating winml objects. fixes model load. --- .../lib/Api.Core/FeatureDescriptorFactory.cpp | 92 +++++++++++++++++++ winml/lib/Api.Core/WinMLAdapter.cpp | 2 +- 2 files changed, 93 insertions(+), 1 deletion(-) diff --git a/winml/lib/Api.Core/FeatureDescriptorFactory.cpp b/winml/lib/Api.Core/FeatureDescriptorFactory.cpp index 62ea8aba96a9e..cdc5a1b5bcf5a 100644 --- a/winml/lib/Api.Core/FeatureDescriptorFactory.cpp +++ b/winml/lib/Api.Core/FeatureDescriptorFactory.cpp @@ -41,6 +41,98 @@ static const char* c_supported_nominal_ranges[] = "NominalRange_0_255"}; namespace Windows::AI::MachineLearning { + + +// since this code is now running inside ONNXRUNTIME we need to shortcut +// this a bit when creating winrt objects. This will help. + +/* extern "C" +HRESULT __stdcall OS_RoGetActivationFactory(HSTRING classId, GUID const& iid, void** factory) noexcept; + +#ifdef _M_IX86 +#pragma comment(linker, "/alternatename:_OS_RoGetActivationFactory@12=_RoGetActivationFactory@12") +#else +#pragma comment(linker, "/alternatename:OS_RoGetActivationFactory=RoGetActivationFactory") +#endif +*/ + +bool starts_with(std::wstring_view value, std::wstring_view match) noexcept +{ + return 0 == value.compare(0, match.size(), match); +} + +EXTERN_C IMAGE_DOS_HEADER __ImageBase; + +std::wstring GetModulePath() +{ + std::wstring val; + wchar_t modulePath[MAX_PATH] = { 0 }; + GetModuleFileNameW((HINSTANCE)&__ImageBase, modulePath, _countof(modulePath)); + wchar_t drive[_MAX_DRIVE]; + wchar_t dir[_MAX_DIR]; + wchar_t filename[_MAX_FNAME]; + wchar_t ext[_MAX_EXT]; + _wsplitpath_s(modulePath, drive, _MAX_DRIVE, dir, _MAX_DIR, filename, _MAX_FNAME, ext, _MAX_EXT); + + val = drive; + val += dir; + + return val; +} + +extern "C" +int32_t WINRT_CALL WINRT_RoGetActivationFactory(void* classId, winrt::guid const& iid, void** factory) noexcept +{ + *factory = nullptr; + HSTRING classId_hstring = (HSTRING)classId; + std::wstring_view name{ WindowsGetStringRawBuffer(classId_hstring, nullptr), WindowsGetStringLen(classId_hstring) }; + HMODULE library{ nullptr }; + + std::wstring winmlDllPath = GetModulePath() + L"Windows.AI.MachineLearning.dll"; + + if (starts_with(name, L"Windows.AI.MachineLearning.")) + { + const wchar_t* libPath = winmlDllPath.c_str(); + library = LoadLibraryW(libPath); + } + else + { + return RoGetActivationFactory(classId_hstring, iid, factory); + } + + if (!library) + { + return HRESULT_FROM_WIN32(GetLastError()); + } + + using DllGetActivationFactory = HRESULT __stdcall(HSTRING classId, void** factory); + auto call = reinterpret_cast(GetProcAddress(library, "DllGetActivationFactory")); + + if (!call) + { + HRESULT const hr = HRESULT_FROM_WIN32(GetLastError()); + WINRT_VERIFY(FreeLibrary(library)); + return hr; + } + + winrt::com_ptr activation_factory; + HRESULT const hr = call(classId_hstring, activation_factory.put_void()); + + if (FAILED(hr)) + { + WINRT_VERIFY(FreeLibrary(library)); + return hr; + } + + if (winrt::guid(iid) != winrt::guid_of()) + { + return activation_factory->QueryInterface(iid, factory); + } + + *factory = activation_factory.detach(); + return S_OK; +} + // Forward declare CreateFeatureDescriptor static winml::ILearningModelFeatureDescriptor CreateFeatureDescriptor( diff --git a/winml/lib/Api.Core/WinMLAdapter.cpp b/winml/lib/Api.Core/WinMLAdapter.cpp index b40318f381537..a60118d5e466f 100644 --- a/winml/lib/Api.Core/WinMLAdapter.cpp +++ b/winml/lib/Api.Core/WinMLAdapter.cpp @@ -327,7 +327,7 @@ class WinMLAdapter : public Microsoft::WRL::RuntimeClass< auto model_proto_inner = new onnx::ModelProto(); THROW_HR_IF_MSG( E_INVALIDARG, - !model_proto_inner->ParseFromZeroCopyStream(&stream) == false, + model_proto_inner->ParseFromZeroCopyStream(&stream) == false, "The stream failed to parse."); auto model_proto_outer = wil::MakeOrThrow(model_proto_inner);