Skip to content

Commit

Permalink
Layer dev paulm (#2507)
Browse files Browse the repository at this point in the history
* commetns for dml graph transformer
fixed ort value passing using the allocatir info

* fixed and coded maps and sequences across the abi

* cleaned up w4's
cleaned up the model info ABI
delayload directml.dll from winml

* cleaned up namepsace aliases.
renamed _winmla to winmla
this was good PR feedback from tiago a while back.
  • Loading branch information
walrusmcd authored Nov 27, 2019
1 parent 197fd9e commit 301d407
Show file tree
Hide file tree
Showing 21 changed files with 71 additions and 106 deletions.
2 changes: 1 addition & 1 deletion winml/dll/module.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ extern "C" BOOL WINAPI DllMain(_In_ HINSTANCE hInstance, DWORD dwReason, _In_ vo

extern "C" HRESULT WINAPI MLCreateOperatorRegistry(_COM_Outptr_ IMLOperatorRegistry** registry) try {
*registry = nullptr;
winrt::com_ptr<_winmla::IWinMLAdapter> adapter;
winrt::com_ptr<winmla::IWinMLAdapter> adapter;
WINML_THROW_IF_FAILED(OrtGetWinMLAdapter(adapter.put()));
return adapter->GetCustomRegistry(registry);
}
Expand Down
8 changes: 4 additions & 4 deletions winml/lib/Api.Core/CpuOrtSessionBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ CpuOrtSessionBuilder::CreateSessionOptions(
HRESULT
CpuOrtSessionBuilder::CreateSession(
OrtSessionOptions* options,
_winmla::IInferenceSession** p_session,
winmla::IInferenceSession** p_session,
onnxruntime::IExecutionProvider** pp_provider) {
RETURN_HR_IF_NULL(E_POINTER, p_session);
RETURN_HR_IF_NULL(E_POINTER, pp_provider);
Expand All @@ -79,15 +79,15 @@ CpuOrtSessionBuilder::CreateSession(
ORT_THROW_IF_ERROR(session->RegisterExecutionProvider(std::move(cpu_provider)));

// assign the session to the out parameter
auto sessionptr = wil::MakeOrThrow<_winmla::InferenceSession>(session.release());
RETURN_IF_FAILED(sessionptr.CopyTo(_uuidof(_winmla::IInferenceSession), (void**)p_session));
auto sessionptr = wil::MakeOrThrow<winmla::InferenceSession>(session.release());
RETURN_IF_FAILED(sessionptr.CopyTo(_uuidof(winmla::IInferenceSession), (void**)p_session));

return S_OK;
}

HRESULT
CpuOrtSessionBuilder::Initialize(
_winmla::IInferenceSession* p_session,
winmla::IInferenceSession* p_session,
onnxruntime::IExecutionProvider* /*p_provider*/
) {
ORT_THROW_IF_ERROR(p_session->get()->Initialize());
Expand Down
6 changes: 3 additions & 3 deletions winml/lib/Api.Core/CpuOrtSessionBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ namespace Windows::AI::MachineLearning::Adapter {

class CpuOrtSessionBuilder : public Microsoft::WRL::RuntimeClass <
Microsoft::WRL::RuntimeClassFlags<Microsoft::WRL::ClassicCom>,
_winmla::IOrtSessionBuilder> {
winmla::IOrtSessionBuilder> {

public:
CpuOrtSessionBuilder();
Expand All @@ -19,11 +19,11 @@ class CpuOrtSessionBuilder : public Microsoft::WRL::RuntimeClass <

HRESULT STDMETHODCALLTYPE CreateSession(
OrtSessionOptions* options,
_winmla::IInferenceSession** p_session,
winmla::IInferenceSession** p_session,
onnxruntime::IExecutionProvider** pp_provider) override;

HRESULT STDMETHODCALLTYPE Initialize(
_winmla::IInferenceSession* p_session,
winmla::IInferenceSession* p_session,
onnxruntime::IExecutionProvider* p_provider) override;
};

Expand Down
8 changes: 4 additions & 4 deletions winml/lib/Api.Core/DmlOrtSessionBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ Microsoft::WRL::ComPtr<IDMLDevice> CreateDmlDevice(ID3D12Device* d3d12Device) {

HRESULT DmlOrtSessionBuilder::CreateSession(
OrtSessionOptions* options,
_winmla::IInferenceSession** p_session,
winmla::IInferenceSession** p_session,
onnxruntime::IExecutionProvider** pp_provider) {
RETURN_HR_IF_NULL(E_POINTER, p_session);
RETURN_HR_IF_NULL(E_POINTER, pp_provider);
Expand All @@ -126,14 +126,14 @@ HRESULT DmlOrtSessionBuilder::CreateSession(
ORT_THROW_IF_ERROR(session->RegisterExecutionProvider(std::move(gpu_provider)));

// assign the session to the out parameter
auto sessionptr = wil::MakeOrThrow<_winmla::InferenceSession>(session.release());
RETURN_IF_FAILED(sessionptr.CopyTo(_uuidof(_winmla::IInferenceSession), (void**)p_session));
auto sessionptr = wil::MakeOrThrow<winmla::InferenceSession>(session.release());
RETURN_IF_FAILED(sessionptr.CopyTo(_uuidof(winmla::IInferenceSession), (void**)p_session));

return S_OK;
}

HRESULT DmlOrtSessionBuilder::Initialize(
_winmla::IInferenceSession* p_session,
winmla::IInferenceSession* p_session,
onnxruntime::IExecutionProvider* p_provider) {
RETURN_HR_IF_NULL(E_INVALIDARG, p_session);
RETURN_HR_IF_NULL(E_INVALIDARG, p_provider);
Expand Down
6 changes: 3 additions & 3 deletions winml/lib/Api.Core/DmlOrtSessionBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ namespace Windows::AI::MachineLearning::Adapter {

class DmlOrtSessionBuilder : public Microsoft::WRL::RuntimeClass <
Microsoft::WRL::RuntimeClassFlags<Microsoft::WRL::ClassicCom>,
_winmla::IOrtSessionBuilder> {
winmla::IOrtSessionBuilder> {

public:
DmlOrtSessionBuilder(ID3D12Device* device, ID3D12CommandQueue* queue);
Expand All @@ -19,11 +19,11 @@ class DmlOrtSessionBuilder : public Microsoft::WRL::RuntimeClass <

HRESULT STDMETHODCALLTYPE CreateSession(
OrtSessionOptions* options,
_winmla::IInferenceSession** p_session,
winmla::IInferenceSession** p_session,
onnxruntime::IExecutionProvider** pp_provider) override;

HRESULT STDMETHODCALLTYPE Initialize(
_winmla::IInferenceSession* p_session,
winmla::IInferenceSession* p_session,
onnxruntime::IExecutionProvider* p_provider) override;

private:
Expand Down
2 changes: 1 addition & 1 deletion winml/lib/Api.Core/FeatureDescriptorFactory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -473,7 +473,7 @@ GetTensorType(
has_unsupported_image_metadata);

if (is_tensor_improperly_annotated_as_image) {
TraceLoggingWrite(_winmla::winml_trace_logging_provider,
TraceLoggingWrite(winmla::winml_trace_logging_provider,
"WinMLInputValidation",
TraceLoggingKeyword(WINML_PROVIDER_KEYWORD_DEFAULT),
TraceLoggingLevel(WINEVENT_LEVEL_WARNING),
Expand Down
14 changes: 7 additions & 7 deletions winml/lib/Api.Core/LotusEnvironment.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ void Windows::AI::MachineLearning::CWinMLLogSink::SendImpl(
switch (message.Severity()) {
case (onnxruntime::logging::Severity::kFATAL): //Telemetry
TraceLoggingWrite(
_winmla::winml_trace_logging_provider,
winmla::winml_trace_logging_provider,
"WinMLLogSink",
TelemetryPrivacyDataTag(PDT_ProductAndServicePerformance),
TraceLoggingKeyword(WINML_PROVIDER_KEYWORD_DEFAULT),
Expand All @@ -29,7 +29,7 @@ void Windows::AI::MachineLearning::CWinMLLogSink::SendImpl(
break;
case (onnxruntime::logging::Severity::kERROR): //Telemetry
TraceLoggingWrite(
_winmla::winml_trace_logging_provider,
winmla::winml_trace_logging_provider,
"WinMLLogSink",
TelemetryPrivacyDataTag(PDT_ProductAndServicePerformance),
TraceLoggingKeyword(WINML_PROVIDER_KEYWORD_DEFAULT),
Expand All @@ -43,7 +43,7 @@ void Windows::AI::MachineLearning::CWinMLLogSink::SendImpl(
break;
case (onnxruntime::logging::Severity::kWARNING):
TraceLoggingWrite(
_winmla::winml_trace_logging_provider,
winmla::winml_trace_logging_provider,
"WinMLLogSink",
TraceLoggingKeyword(WINML_PROVIDER_KEYWORD_DEFAULT),
TraceLoggingLevel(WINEVENT_LEVEL_WARNING),
Expand All @@ -55,7 +55,7 @@ void Windows::AI::MachineLearning::CWinMLLogSink::SendImpl(
break;
case (onnxruntime::logging::Severity::kINFO):
TraceLoggingWrite(
_winmla::winml_trace_logging_provider,
winmla::winml_trace_logging_provider,
"WinMLLogSink",
TraceLoggingKeyword(WINML_PROVIDER_KEYWORD_DEFAULT),
TraceLoggingLevel(WINEVENT_LEVEL_INFO),
Expand All @@ -69,7 +69,7 @@ void Windows::AI::MachineLearning::CWinMLLogSink::SendImpl(
__fallthrough; //Default is Verbose too.
default:
TraceLoggingWrite(
_winmla::winml_trace_logging_provider,
winmla::winml_trace_logging_provider,
"WinMLLogSink",
TraceLoggingKeyword(WINML_PROVIDER_KEYWORD_DEFAULT),
TraceLoggingLevel(WINEVENT_LEVEL_VERBOSE),
Expand All @@ -87,7 +87,7 @@ void Windows::AI::MachineLearning::CWinMLLogSink::SendImpl(
void Windows::AI::MachineLearning::CWinMLLogSink::SendProfileEvent(onnxruntime::profiling::EventRecord& eventRecord) const {
if (eventRecord.cat == onnxruntime::profiling::EventCategory::NODE_EVENT) {
TraceLoggingWrite(
_winmla::winml_trace_logging_provider,
winmla::winml_trace_logging_provider,
"OnnxRuntimeProfiling",
TraceLoggingKeyword(WINML_PROVIDER_KEYWORD_LOTUS_PROFILING),
TraceLoggingLevel(WINEVENT_LEVEL_VERBOSE),
Expand All @@ -102,7 +102,7 @@ void Windows::AI::MachineLearning::CWinMLLogSink::SendProfileEvent(onnxruntime::
TraceLoggingString(eventRecord.args["provider"].c_str(), "Execution Provider"));
} else {
TraceLoggingWrite(
_winmla::winml_trace_logging_provider,
winmla::winml_trace_logging_provider,
"OnnxRuntimeProfiling",
TraceLoggingKeyword(WINML_PROVIDER_KEYWORD_LOTUS_PROFILING),
TraceLoggingLevel(WINEVENT_LEVEL_VERBOSE),
Expand Down
23 changes: 0 additions & 23 deletions winml/lib/Api.Core/inc/WinMLAdapter.h
Original file line number Diff line number Diff line change
Expand Up @@ -141,27 +141,4 @@ class InferenceSession : public Microsoft::WRL::RuntimeClass <
std::shared_ptr<onnxruntime::InferenceSession> session_;
};

// header only code to enable smart pointers on abstract ort objects
template <typename T>
class OrtObject {
public:
OrtObject() {
p_ = nullptr;
}

OrtObject(T* m) {
p_ = m;
}

virtual ~OrtObject() {
if (p_ != nullptr) {
ReleaseOrtObject(p_);
}
}
T* get() { return p_; }
private:
T* p_;
};


} // namespace Windows::AI::MachineLearning::Adapter
6 changes: 3 additions & 3 deletions winml/lib/Api/ImageFeatureValue.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,7 @@ static void GPUTensorize(
com_ptr<LearningModelSession> spSession,
void* pAllocatedResource,
WinML::BindingContext& context) {
com_ptr<_winmla::IWinMLAdapter> adapter;
com_ptr<winmla::IWinMLAdapter> adapter;
WINML_THROW_IF_FAILED(OrtGetWinMLAdapter(adapter.put()));

auto d3dResource =
Expand Down Expand Up @@ -505,7 +505,7 @@ HRESULT ImageFeatureValue::GetOrtValue(WinML::BindingContext& context, OrtValue*
auto provider = spSession->GetExecutionProvider();

// and the adapter
com_ptr<_winmla::IWinMLAdapter> adapter;
com_ptr<winmla::IWinMLAdapter> adapter;
WINML_THROW_IF_FAILED(OrtGetWinMLAdapter(adapter.put()));

// create the OrtValue
Expand Down Expand Up @@ -552,7 +552,7 @@ HRESULT ImageFeatureValue::UpdateSourceResourceData(BindingContext& context, Ort
auto spSession = context.session.as<LearningModelSession>();
auto spDevice = spSession->Device().as<LearningModelDevice>();

com_ptr<_winmla::IWinMLAdapter> adapter;
com_ptr<winmla::IWinMLAdapter> adapter;
WINML_THROW_IF_FAILED(OrtGetWinMLAdapter(adapter.put()));

// Get the output tensor raw data
Expand Down
10 changes: 5 additions & 5 deletions winml/lib/Api/LearningModel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -253,9 +253,9 @@ LearningModel::LoadFromStream(
}
WINML_CATCH_ALL

_winmla::IModelProto*
winmla::IModelProto*
LearningModel::DetachModelProto() {
com_ptr<_winmla::IModelProto> detached_model_proto;
com_ptr<winmla::IModelProto> detached_model_proto;
if (model_proto_ != nullptr) {
detached_model_proto.attach(model_proto_.detach());

Expand All @@ -265,15 +265,15 @@ LearningModel::DetachModelProto() {
return detached_model_proto.detach();
}

_winmla::IModelProto*
winmla::IModelProto*
LearningModel::CopyModelProto() {
if (model_proto_ == nullptr) {
return nullptr;
}

com_ptr<_winmla::IWinMLAdapter> adapter;
com_ptr<winmla::IWinMLAdapter> adapter;
WINML_THROW_IF_FAILED(OrtGetWinMLAdapter(adapter.put()));
com_ptr<_winmla::IModelProto> model_proto;
com_ptr<winmla::IModelProto> model_proto;
WINML_THROW_IF_FAILED(adapter->CreateModelProto(model_proto_.get(), model_proto.put()));

return model_proto.detach();
Expand Down
33 changes: 10 additions & 23 deletions winml/lib/Api/LearningModel.h
Original file line number Diff line number Diff line change
Expand Up @@ -91,35 +91,22 @@ struct LearningModel : LearningModelT<LearningModel> {

public:
/* Non-ABI methods */
bool
IsDisposed();

IMLOperatorRegistry*
GetOperatorRegistry();

_winmla::IModelProto*
DetachModelProto();

_winmla::IModelProto*
CopyModelProto();
bool IsDisposed();
IMLOperatorRegistry* GetOperatorRegistry();
winmla::IModelProto* DetachModelProto();
winmla::IModelProto* CopyModelProto();

private:
void
Initialize();

void
LogCreationEvent(
bool fromStream = false);

void
ModelUseFP16(
void Initialize();
void LogCreationEvent(bool fromStream = false);
void ModelUseFP16(
winml::ILearningModelFeatureDescriptor descriptor,
bool& use_fp16);

private:
com_ptr<_winmla::IWinMLAdapter> adapter_;
com_ptr<_winmla::IModelProto> model_proto_;
com_ptr<_winmla::IModelInfo> model_info_;
com_ptr<winmla::IWinMLAdapter> adapter_;
com_ptr<winmla::IModelProto> model_proto_;
com_ptr<winmla::IModelInfo> model_info_;
ILearningModelOperatorProvider operator_provider_;
};

Expand Down
3 changes: 2 additions & 1 deletion winml/lib/Api/LearningModelBinding.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -613,7 +613,8 @@ void LearningModelBinding::BindUnboundOutputs()

// Add all unbound outputs to binding collection
for (const auto& unbound_output : unbound_output_names) {
WINML_THROW_IF_FAILED(BindOutput(unbound_output, Ort::Value(nullptr)));
Ort::Value out(nullptr);
WINML_THROW_IF_FAILED(BindOutput(unbound_output, out));
}
}

Expand Down
2 changes: 1 addition & 1 deletion winml/lib/Api/LearningModelBinding.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ struct LearningModelBinding : LearningModelBindingT<LearningModelBinding, ILearn

std::unordered_map<std::string, ProviderInfo> m_providers;

com_ptr<_winmla::IWinMLAdapter> adapter_;
com_ptr<winmla::IWinMLAdapter> adapter_;
std::vector<std::string> input_names_;
std::vector<Ort::Value> inputs_;
std::vector<std::string> output_names_;
Expand Down
Loading

0 comments on commit 301d407

Please sign in to comment.