Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extend DML kernels #2641

Merged
merged 9 commits into from
Dec 21, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -94,14 +94,22 @@ namespace winrt::Windows::AI::MachineLearning::implementation
const void* executionHandle,
DmlGraphNodeCreateInfo* graphNodeCreateInfo
)>;

// Registration data describing how a kernel participates in fused DML graph
// compilation. Stored per-KernelDef and consulted during graph partitioning.
struct GraphNodeFactoryRegistration
{
// Callback that translates an ONNX Runtime node into a DML graph node
// (signature: the GraphNodeFactory alias declared above).
GraphNodeFactory factory;
// If set, the kernel may only be fused when the node has exactly this
// many inputs; unset means any input count is acceptable.
std::optional<uint32_t> requiredInputCount;
// Indices of inputs that must be constant-folded CPU tensors for the
// factory to be usable.
std::vector<uint32_t> requiredConstantCpuInputs;
// When true, all non-constant inputs/outputs must use float formats for
// the node to be eligible for graph fusion.
bool requiresFloatFormatsExceptConstInputs = false;
};

using GraphNodeFactoryMap = std::unordered_map<onnxruntime::KernelDef*, std::shared_ptr<GraphNodeFactoryRegistration>>;
using KernelSupportQuery = std::function<bool(const onnxruntime::Node& node)>;

// Extended per-kernel registration data for internal operators, used during
// partitioning. Supersedes storing GraphNodeFactoryRegistration directly so
// that constant-CPU-input and support-query info exist even for kernels that
// do not support graph fusion.
struct InternalRegistrationInfo
{
// Indices of inputs that must be constant-folded CPU tensors.
std::vector<uint32_t> requiredConstantCpuInputs;
// Present only when the kernel supports usage inside a fused DML graph.
std::optional<GraphNodeFactoryRegistration> graphNodeFactoryRegistration;
// Optional predicate answering whether the DML EP supports a given node
// (signature: the KernelSupportQuery alias declared above); may be empty.
KernelSupportQuery supportQuery;
};

using InternalRegistrationInfoMap = std::unordered_map<onnxruntime::KernelDef*, std::shared_ptr<InternalRegistrationInfo>>;
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ namespace winrt::Windows::AI::MachineLearning::implementation

AbiCustomRegistry::AbiCustomRegistry() :
m_kernelRegistry(std::make_shared<onnxruntime::CustomRegistry>()),
m_graphNodeFactoryMap(std::make_shared<GraphNodeFactoryMap>())
m_internalRegInfoMap(std::make_shared<InternalRegistrationInfoMap>())
{
}

Expand Down Expand Up @@ -321,13 +321,14 @@ HRESULT STDMETHODCALLTYPE AbiCustomRegistry::RegisterOperatorKernel(
IMLOperatorKernelFactory* operatorKernelFactory,
_In_opt_ IMLOperatorShapeInferrer* shapeInferrer) const noexcept
{
return RegisterOperatorKernel(opKernel, operatorKernelFactory, shapeInferrer, false, false, false);
return RegisterOperatorKernel(opKernel, operatorKernelFactory, shapeInferrer, nullptr, false, false, false);
}

HRESULT STDMETHODCALLTYPE AbiCustomRegistry::RegisterOperatorKernel(
const MLOperatorKernelDescription* opKernel,
IMLOperatorKernelFactory* operatorKernelFactory,
_In_opt_ IMLOperatorShapeInferrer* shapeInferrer,
_In_opt_ IMLOperatorSupportQueryPrivate* supportQuery,
bool isInternalOperator,
bool canAliasFirstInput,
bool supportsGraph,
Expand Down Expand Up @@ -449,63 +450,91 @@ HRESULT STDMETHODCALLTYPE AbiCustomRegistry::RegisterOperatorKernel(
};

onnxruntime::KernelCreateInfo create_info(builder.Build(), lotusKernelCreateFn);
onnxruntime::KernelDef* kernelDef = create_info.kernel_def.get();

if (supportsGraph)
if (isInternalOperator)
{
auto regInfo = std::make_shared<InternalRegistrationInfo>();
regInfo->requiredConstantCpuInputs = constantCpuInputCapture;

// Only internal operators support usage in DML graphs
if (!isInternalOperator)
if (supportsGraph)
{
THROW_HR(E_INVALIDARG);
GraphNodeFactoryRegistration graphReg;
graphReg.factory =
[kernelFactoryCapture,
requiresInputShapesAtCreation,
requiresOutputShapesAtCreation,
shapeInferrerCapture,
defaultAttributesCapture,
constantCpuInputCapture](const onnxruntime::Node& node, MLOperatorTensorGetter& constantInputGetter, const void* executionHandle, DmlGraphNodeCreateInfo* graphNodeCreateInfo)
{
onnxruntime::ProtoHelperNodeContext nodeContext(node);
onnxruntime::OpNodeProtoHelper<onnxruntime::ProtoHelperNodeContext> protoHelper(&nodeContext);

// Use the same list of required constant inputs for the shape inferrer and the kernel.
EdgeShapes outputShapes;
InferAndVerifyOutputSizes(node, &defaultAttributesCapture, shapeInferrerCapture.Get(), constantCpuInputCapture, constantInputGetter, nullptr, outputShapes);

// Create the kernel while allowing input shape and output shape queries according to options
ComPtr<DmlGraphOpKernelInfoWrapper> kernelInfoWrapper = wil::MakeOrThrow<DmlGraphOpKernelInfoWrapper>(
&protoHelper,
executionHandle,
true,
&outputShapes,
&defaultAttributesCapture,
graphNodeCreateInfo,
constantCpuInputCapture,
constantInputGetter);

Microsoft::WRL::ComPtr<IMLOperatorKernel> kernel;
THROW_IF_FAILED(kernelFactoryCapture->CreateKernel(kernelInfoWrapper.Get(), kernel.GetAddressOf()));
kernelInfoWrapper->Close();
};

if (requiredInputCountForGraph)
{
graphReg.requiredInputCount = *requiredInputCountForGraph;
}

graphReg.requiresFloatFormatsExceptConstInputs = requiresFloatFormatsForGraph;
regInfo->graphNodeFactoryRegistration = graphReg;
}

auto registration = std::make_shared<GraphNodeFactoryRegistration>();
if (supportQuery)
{
ComPtr<IMLOperatorSupportQueryPrivate> supportQueryCapture = supportQuery;

registration->factory =
[kernelFactoryCapture,
requiresInputShapesAtCreation,
requiresOutputShapesAtCreation,
shapeInferrerCapture,
defaultAttributesCapture,
constantCpuInputCapture](const onnxruntime::Node& node, MLOperatorTensorGetter& constantInputGetter, const void* executionHandle, DmlGraphNodeCreateInfo* graphNodeCreateInfo)
regInfo->supportQuery = [supportQueryCapture, defaultAttributesCapture](const onnxruntime::Node& node)
{
onnxruntime::ProtoHelperNodeContext nodeContext(node);
onnxruntime::OpNodeProtoHelper<onnxruntime::ProtoHelperNodeContext> protoHelper(&nodeContext);

// Use the same list of required constant inputs for the shape inferrer and the kernel.
EdgeShapes outputShapes;
InferAndVerifyOutputSizes(node, &defaultAttributesCapture, shapeInferrerCapture.Get(), constantCpuInputCapture, constantInputGetter, nullptr, outputShapes);


// Create the kernel while allowing input shape and output shape queries according to options
ComPtr<DmlGraphOpKernelInfoWrapper> kernelInfoWrapper = wil::MakeOrThrow<DmlGraphOpKernelInfoWrapper>(
ComPtr<MLSupportQueryContext> supportContext = wil::MakeOrThrow<MLSupportQueryContext>(
&protoHelper,
executionHandle,
true,
&outputShapes,
&defaultAttributesCapture,
graphNodeCreateInfo,
constantCpuInputCapture,
constantInputGetter);

Microsoft::WRL::ComPtr<IMLOperatorKernel> kernel;
THROW_IF_FAILED(kernelFactoryCapture->CreateKernel(kernelInfoWrapper.Get(), kernel.GetAddressOf()));
kernelInfoWrapper->Close();
};
&defaultAttributesCapture);

if (requiredInputCountForGraph)
{
registration->requiredInputCount = *requiredInputCountForGraph;
BOOL bSupported = FALSE;
THROW_IF_FAILED(supportQueryCapture->QuerySupport(supportContext.Get(), &bSupported));
return !!bSupported;
};
}

registration->requiresFloatFormatsExceptConstInputs = requiresFloatFormatsForGraph;
registration->requiredConstantCpuInputs = constantCpuInputCapture;

onnxruntime::KernelDef* kernelDef = create_info.kernel_def.get();
THROW_IF_NOT_OK(m_kernelRegistry->RegisterCustomKernel(create_info));
(*m_graphNodeFactoryMap)[kernelDef] = registration;
(*m_internalRegInfoMap)[kernelDef] = regInfo;
}
else
{
// For backward compatibility, this does not propagate errors
// Currently unsupported for external operators
if (canAliasFirstInput || supportsGraph || requiredInputCountForGraph ||
requiresFloatFormatsForGraph || requiredConstantCpuInputs)
{
THROW_HR(E_INVALIDARG);
}

//
// For backward compatibility, this does not propagate errors for external operators
m_kernelRegistry->RegisterCustomKernel(create_info);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ class AbiCustomRegistry : public WRL::Base<IMLOperatorRegistry, IMLOperatorRegis
const MLOperatorKernelDescription* operatorKernel,
IMLOperatorKernelFactory* operatorKernelFactory,
_In_opt_ IMLOperatorShapeInferrer* shapeInferrer,
_In_opt_ IMLOperatorSupportQueryPrivate* supportQuery,
bool isInternalOperator,
bool canAliasFirstInput,
bool supportsGraph,
Expand Down Expand Up @@ -79,9 +80,9 @@ class AbiCustomRegistry : public WRL::Base<IMLOperatorRegistry, IMLOperatorRegis
return m_kernelRegistry;
}

std::shared_ptr<GraphNodeFactoryMap> GetGraphNodeFactoryMap() const
std::shared_ptr<InternalRegistrationInfoMap> GetInternalRegInfoMap() const
{
return m_graphNodeFactoryMap;
return m_internalRegInfoMap;
}

private:
Expand All @@ -104,8 +105,9 @@ class AbiCustomRegistry : public WRL::Base<IMLOperatorRegistry, IMLOperatorRegis
// (see LotusOpSchemaRegistry::GetSchemaAndHistory).
mutable std::map<std::pair<int, int>, std::shared_ptr<onnxruntime::CustomRegistry>> m_customRegistryOpsetVerMap;

// Map between Lotus KernelDefs and graph node factories used for fusing nodes for graph compilation
mutable std::shared_ptr<GraphNodeFactoryMap> m_graphNodeFactoryMap;
// Map between Lotus KernelDefs and extended data used during partitioning
mutable std::shared_ptr<InternalRegistrationInfoMap> m_internalRegInfoMap;

};

} // namespace winrt::Windows::AI::MachineLearning::implementation
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ namespace Dml

static void CreateDmlKernelRegistry(
_Outptr_ std::shared_ptr<onnxruntime::KernelRegistry>* registry,
_Outptr_ std::shared_ptr<const GraphNodeFactoryMap>* graphNodeFactoryMap)
_Outptr_ std::shared_ptr<const InternalRegistrationInfoMap>* internalRegInfoMap)
{
ComPtr<AbiCustomRegistry> abiRegistry = wil::MakeOrThrow<AbiCustomRegistry>();
Dml::RegisterDmlOperators(abiRegistry.Get());
Expand All @@ -54,7 +54,7 @@ namespace Dml

auto customRegistry = *abiRegistry->GetRegistries().begin();
*registry = customRegistry->GetKernelRegistry();
*graphNodeFactoryMap = abiRegistry->GetGraphNodeFactoryMap();
*internalRegInfoMap = abiRegistry->GetInternalRegInfoMap();
}

ExecutionProvider::ExecutionProvider(
Expand Down Expand Up @@ -176,7 +176,7 @@ namespace Dml
m_cpuInputAllocator = std::make_shared<CPUAllocator>(OrtMemType::OrtMemTypeCPUInput);
m_cpuOutputAllocator = std::make_shared<CPUAllocator>(OrtMemType::OrtMemTypeCPUOutput);

CreateDmlKernelRegistry(&m_kernelRegistry, &m_graphNodeFactoryMap);
CreateDmlKernelRegistry(&m_kernelRegistry, &m_internalRegInfoMap);
}

HRESULT __stdcall ExecutionProviderImpl::GetD3DDevice(_COM_Outptr_ ID3D12Device** d3dDevice) const noexcept
Expand Down Expand Up @@ -493,7 +493,14 @@ namespace Dml
{
std::string partitionKernelPrefix = std::to_string(m_partitionKernelPrefixVal++) + "_";
uint32_t deviceDataTypeMask = GetSuppportedDeviceDataTypeMask();
return PartitionGraph(graph, *m_graphNodeFactoryMap, registries, deviceDataTypeMask, m_kernelRegistry.get(), partitionKernelPrefix);

return PartitionGraph(graph,
*m_internalRegInfoMap,
registries,
deviceDataTypeMask,
m_kernelRegistry.get(),
partitionKernelPrefix
);
}

Status ExecutionProviderImpl::CopyTensor(const onnxruntime::Tensor& src, onnxruntime::Tensor& dst) const
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ namespace Dml
std::shared_ptr<CPUAllocator> m_cpuInputAllocator;
std::shared_ptr<CPUAllocator> m_cpuOutputAllocator;
std::shared_ptr<onnxruntime::KernelRegistry> m_kernelRegistry;
std::shared_ptr<const winrt::Windows::AI::MachineLearning::implementation::GraphNodeFactoryMap> m_graphNodeFactoryMap;
std::shared_ptr<const winrt::Windows::AI::MachineLearning::implementation::InternalRegistrationInfoMap> m_internalRegInfoMap;
mutable uint64_t m_partitionKernelPrefixVal = 0;

bool m_closed = false;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ namespace Dml::GraphDescBuilder
const onnxruntime::Node& node = *graph.GetNode(sortedNodeIndex);

const GraphNodeProperties& graphNodeProps = graphNodePropertyMap.find(GetUniqueNodeName(node))->second;
const auto& requiredConstantCpuInputs = graphNodeProps.graphNodeFactoryRegistration->requiredConstantCpuInputs;
const auto& requiredConstantCpuInputs = graphNodeProps.internalRegInfo->requiredConstantCpuInputs;

MLOperatorTensorGetter constantCpuNodeInputGetter = [&node, &constantCpuGraphInputGetter, &requiredConstantCpuInputs](uint32_t inputIndex)
{
Expand All @@ -144,7 +144,7 @@ namespace Dml::GraphDescBuilder
};

DmlGraphNodeCreateInfo graphNodeInfo;
graphNodeProps.graphNodeFactoryRegistration->factory(
graphNodeProps.internalRegInfo->graphNodeFactoryRegistration->factory(
node,
constantCpuNodeInputGetter,
executionHandle,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ namespace Dml
{
struct GraphNodeProperties
{
std::shared_ptr<const winrt::Windows::AI::MachineLearning::implementation::GraphNodeFactoryRegistration>
graphNodeFactoryRegistration;
std::shared_ptr<const winrt::Windows::AI::MachineLearning::implementation::InternalRegistrationInfo>
internalRegInfo;

// These are currently passed from the partitioning step since the only DML operators currently
// supporting graph nodes don't customize the order of edges or shapes, other than coercing
Expand Down
Loading