Skip to content

Commit

Permalink
Implement DML execution provider functions from adapter (#2846)
Browse files Browse the repository at this point in the history
* Implement DML execution provider functions from adapter

* Use functions in OnnxruntimeEngine.cpp
  • Loading branch information
ryanlai2 authored Jan 16, 2020
1 parent 272d829 commit 8fb6419
Show file tree
Hide file tree
Showing 5 changed files with 54 additions and 10 deletions.
2 changes: 1 addition & 1 deletion cmake/external/protobuf
6 changes: 3 additions & 3 deletions winml/adapter/winml_adapter_apis.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,9 @@ ORT_API_STATUS(SessionRegisterCustomRegistry, _In_ OrtSession* session, _In_ IML

// Dml methods (TODO need to figure out how these need to move to session somehow...)
// ORT_API_STATUS(DmlExecutionProviderSetDefaultRoundingMode, _In_ const OrtExecutionProvider* dml_provider, _In_ bool is_enabled);
// ORT_API_STATUS(DmlExecutionProviderFlushContext, _In_ OrtExecutionProvider * dml_provider);
// ORT_API_STATUS(DmlExecutionProviderTrimUploadHeap, _In_ const OrtExecutionProvider* dml_provider);
// ORT_API_STATUS(DmlExecutionProviderReleaseCompletedReferences, _In_ const OrtExecutionProvider* dml_provider);
ORT_API_STATUS(DmlExecutionProviderFlushContext, _In_ OrtExecutionProvider * dml_provider);
ORT_API_STATUS(DmlExecutionProviderTrimUploadHeap, _In_ OrtExecutionProvider* dml_provider);
ORT_API_STATUS(DmlExecutionProviderReleaseCompletedReferences, _In_ OrtExecutionProvider* dml_provider);
// ORT_API_STATUS(DmlCreateGPUAllocationFromD3DResource, _In_ ID3D12Resource* pResource, _Out_ void* dml_resource);
// ORT_API_STATUS(DmlFreeGPUAllocation, _In_ void* ptr);

Expand Down
6 changes: 3 additions & 3 deletions winml/adapter/winml_adapter_c_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,9 @@ static constexpr WinmlAdapterApi winml_adapter_api_1 = {

// Dml methods (TODO need to figure out how these need to move to session somehow...)
nullptr, //OrtStatus*(ORT_API_CALL* DmlExecutionProviderSetDefaultRoundingMode)(_In_ bool is_enabled)NO_EXCEPTION;
nullptr, // OrtStatus*(ORT_API_CALL* DmlExecutionProviderFlushContext(onnxruntime::IExecutionProvider * dml_provider)NO_EXCEPTION;
nullptr, // OrtStatus*(ORT_API_CALL* DmlExecutionProviderTrimUploadHeap(onnxruntime::IExecutionProvider* dml_provider)NO_EXCEPTION;
nullptr, // OrtStatus*(ORT_API_CALL* DmlExecutionProviderReleaseCompletedReferences(onnxruntime::IExecutionProvider* dml_provider)NO_EXCEPTION;
&winmla::DmlExecutionProviderFlushContext,
&winmla::DmlExecutionProviderTrimUploadHeap,
&winmla::DmlExecutionProviderReleaseCompletedReferences,

nullptr, // OrtStatus*(ORT_API_CALL* DmlCreateGPUAllocationFromD3DResource)(_In_ ID3D12Resource* pResource, _Out_ void* dml_resource)NO_EXCEPTION;
nullptr, // OrtStatus*(ORT_API_CALL* DmlFreeGPUAllocation)(_In_ void* ptr)NO_EXCEPTION;
Expand Down
25 changes: 25 additions & 0 deletions winml/adapter/winml_adapter_dml.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include "core/framework/error_code_helper.h"

#include "core/providers/dml/dml_provider_factory.h"
#include "core/providers/dml/DmlExecutionProvider/inc/DmlExecutionProvider.h"

namespace winmla = Windows::AI::MachineLearning::Adapter;

Expand Down Expand Up @@ -45,4 +46,28 @@ ORT_API_STATUS_IMPL(winmla::OrtSessionOptionsAppendExecutionProviderEx_DML, _In_
ID3D12Device* d3d_device, ID3D12CommandQueue* queue) {
auto dml_device = CreateDmlDevice(d3d_device);
return OrtSessionOptionsAppendExecutionProviderEx_DML(options, dml_device.Get(), queue);
}

ORT_API_STATUS_IMPL(winmla::DmlExecutionProviderFlushContext, _In_ OrtExecutionProvider* dml_provider) {
API_IMPL_BEGIN
auto dml_provider_internal = reinterpret_cast<::onnxruntime::IExecutionProvider*>(dml_provider);
Dml::FlushContext(dml_provider_internal);
return nullptr;
API_IMPL_END
}

ORT_API_STATUS_IMPL(winmla::DmlExecutionProviderTrimUploadHeap, _In_ OrtExecutionProvider* dml_provider) {
API_IMPL_BEGIN
auto dml_provider_internal = reinterpret_cast<::onnxruntime::IExecutionProvider*>(dml_provider);
Dml::TrimUploadHeap(dml_provider_internal);
return nullptr;
API_IMPL_END
}

ORT_API_STATUS_IMPL(winmla::DmlExecutionProviderReleaseCompletedReferences, _In_ OrtExecutionProvider* dml_provider) {
API_IMPL_BEGIN
auto dml_provider_internal = reinterpret_cast<::onnxruntime::IExecutionProvider*>(dml_provider);
Dml::ReleaseCompletedReferences(dml_provider_internal);
return nullptr;
API_IMPL_END
}
25 changes: 22 additions & 3 deletions winml/lib/Api.Ort/OnnxruntimeEngine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,15 +73,34 @@ HRESULT OnnxruntimeEngine::StartProfiling() {
}

HRESULT OnnxruntimeEngine::FlushContext() {
return E_NOTIMPL;
auto winml_adapter_api = engine_factory_->UseWinmlAdapterApi();

const OrtExecutionProvider* ort_provider;
winml_adapter_api->SessionGetExecutionProvider(session_.get(), 0, &ort_provider);

winml_adapter_api->DmlExecutionProviderFlushContext(ort_provider);
return S_OK;
}

HRESULT OnnxruntimeEngine::TrimUploadHeap() {
return E_NOTIMPL;
auto winml_adapter_api = engine_factory_->UseWinmlAdapterApi();

const OrtExecutionProvider* ort_provider;
winml_adapter_api->SessionGetExecutionProvider(session_.get(), 0, &ort_provider);

winml_adapter_api->DmlExecutionProviderTrimUploadHeap(ort_provider);
return S_OK;

}

HRESULT OnnxruntimeEngine::ReleaseCompletedReferences() {
return E_NOTIMPL;
auto winml_adapter_api = engine_factory_->UseWinmlAdapterApi();

const OrtExecutionProvider* ort_provider;
winml_adapter_api->SessionGetExecutionProvider(session_.get(), 0, &ort_provider);

winml_adapter_api->DmlExecutionProviderReleaseCompletedReferences(ort_provider);
return S_OK;
}

HRESULT OnnxruntimeEngine::CopyOneInputAcrossDevices(const char* input_name, const IValue* src, IValue** dest) {
Expand Down

0 comments on commit 8fb6419

Please sign in to comment.