diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/IWinmlExecutionProvider.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/IWinmlExecutionProvider.h
index 37b88e891bafc..c34dfbc2d93d6 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/IWinmlExecutionProvider.h
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/IWinmlExecutionProvider.h
@@ -94,14 +94,22 @@ namespace winrt::Windows::AI::MachineLearning::implementation
         const void* executionHandle,
         DmlGraphNodeCreateInfo* graphNodeCreateInfo
     )>;
-
+
     struct GraphNodeFactoryRegistration
     {
         GraphNodeFactory factory;
         std::optional<uint32_t> requiredInputCount;
-        std::vector<uint32_t> requiredConstantCpuInputs;
         bool requiresFloatFormatsExceptConstInputs = false;
     };
 
-    using GraphNodeFactoryMap = std::unordered_map<onnxruntime::KernelDef*, std::shared_ptr<GraphNodeFactoryRegistration>>;
+    using KernelSupportQuery = std::function<bool(const onnxruntime::Node& node)>;
+
+    struct InternalRegistrationInfo
+    {
+        std::vector<uint32_t> requiredConstantCpuInputs;
+        std::optional<GraphNodeFactoryRegistration> graphNodeFactoryRegistration;
+        KernelSupportQuery supportQuery;
+    };
+
+    using InternalRegistrationInfoMap = std::unordered_map<onnxruntime::KernelDef*, std::shared_ptr<InternalRegistrationInfo>>;
 }
\ No newline at end of file
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/AbiCustomRegistry.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/AbiCustomRegistry.cpp
index 5365d10c9f03e..76b08ca895095 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/AbiCustomRegistry.cpp
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/AbiCustomRegistry.cpp
@@ -9,7 +9,7 @@ namespace winrt::Windows::AI::MachineLearning::implementation
 
 AbiCustomRegistry::AbiCustomRegistry() :
     m_kernelRegistry(std::make_shared()),
-    m_graphNodeFactoryMap(std::make_shared<GraphNodeFactoryMap>())
+    m_internalRegInfoMap(std::make_shared<InternalRegistrationInfoMap>())
 {
 }
 
@@ -321,13 +321,14 @@ HRESULT STDMETHODCALLTYPE AbiCustomRegistry::RegisterOperatorKernel(
     IMLOperatorKernelFactory* operatorKernelFactory,
     _In_opt_ IMLOperatorShapeInferrer* shapeInferrer) const noexcept
 {
-    return RegisterOperatorKernel(opKernel, operatorKernelFactory, shapeInferrer, false, false, false);
+    return RegisterOperatorKernel(opKernel, operatorKernelFactory, shapeInferrer, nullptr, false, false, false);
 }
 
 HRESULT STDMETHODCALLTYPE AbiCustomRegistry::RegisterOperatorKernel(
     const MLOperatorKernelDescription* opKernel,
     IMLOperatorKernelFactory* operatorKernelFactory,
     _In_opt_ IMLOperatorShapeInferrer* shapeInferrer,
+    _In_opt_ IMLOperatorSupportQueryPrivate* supportQuery,
     bool isInternalOperator,
     bool canAliasFirstInput,
     bool supportsGraph,
@@ -449,63 +450,91 @@ HRESULT STDMETHODCALLTYPE AbiCustomRegistry::RegisterOperatorKernel(
     };
 
     onnxruntime::KernelCreateInfo create_info(builder.Build(), lotusKernelCreateFn);
+    onnxruntime::KernelDef* kernelDef = create_info.kernel_def.get();
 
-    if (supportsGraph)
+    if (isInternalOperator)
     {
+        auto regInfo = std::make_shared<InternalRegistrationInfo>();
+        regInfo->requiredConstantCpuInputs = constantCpuInputCapture;
+
         // Only internal operators support usage in DML graphs
-        if (!isInternalOperator)
+        if (supportsGraph)
         {
-            THROW_HR(E_INVALIDARG);
+            GraphNodeFactoryRegistration graphReg;
+            graphReg.factory =
+                [kernelFactoryCapture,
+                requiresInputShapesAtCreation,
+                requiresOutputShapesAtCreation,
+                shapeInferrerCapture,
+                defaultAttributesCapture,
+                constantCpuInputCapture](const onnxruntime::Node& node, MLOperatorTensorGetter& constantInputGetter, const void* executionHandle, DmlGraphNodeCreateInfo* graphNodeCreateInfo)
+            {
+                onnxruntime::ProtoHelperNodeContext nodeContext(node);
+                onnxruntime::OpNodeProtoHelper<onnxruntime::ProtoHelperNodeContext> protoHelper(&nodeContext);
+
+                // Use the same list
of required constant inputs for the shape inferrer and the kernel. + EdgeShapes outputShapes; + InferAndVerifyOutputSizes(node, &defaultAttributesCapture, shapeInferrerCapture.Get(), constantCpuInputCapture, constantInputGetter, nullptr, outputShapes); + + // Create the kernel while allowing input shape and output shape queries according to options + ComPtr kernelInfoWrapper = wil::MakeOrThrow( + &protoHelper, + executionHandle, + true, + &outputShapes, + &defaultAttributesCapture, + graphNodeCreateInfo, + constantCpuInputCapture, + constantInputGetter); + + Microsoft::WRL::ComPtr kernel; + THROW_IF_FAILED(kernelFactoryCapture->CreateKernel(kernelInfoWrapper.Get(), kernel.GetAddressOf())); + kernelInfoWrapper->Close(); + }; + + if (requiredInputCountForGraph) + { + graphReg.requiredInputCount = *requiredInputCountForGraph; + } + + graphReg.requiresFloatFormatsExceptConstInputs = requiresFloatFormatsForGraph; + regInfo->graphNodeFactoryRegistration = graphReg; } - auto registration = std::make_shared(); + if (supportQuery) + { + ComPtr supportQueryCapture = supportQuery; - registration->factory = - [kernelFactoryCapture, - requiresInputShapesAtCreation, - requiresOutputShapesAtCreation, - shapeInferrerCapture, - defaultAttributesCapture, - constantCpuInputCapture](const onnxruntime::Node& node, MLOperatorTensorGetter& constantInputGetter, const void* executionHandle, DmlGraphNodeCreateInfo* graphNodeCreateInfo) + regInfo->supportQuery = [supportQueryCapture, defaultAttributesCapture](const onnxruntime::Node& node) { onnxruntime::ProtoHelperNodeContext nodeContext(node); onnxruntime::OpNodeProtoHelper protoHelper(&nodeContext); - - // Use the same list of required constant inputs for the shape inferrer and the kernel. - EdgeShapes outputShapes; - InferAndVerifyOutputSizes(node, &defaultAttributesCapture, shapeInferrerCapture.Get(), constantCpuInputCapture, constantInputGetter, nullptr, outputShapes); - + // Create the kernel while allowing input shape and output shape queries according to options - ComPtr kernelInfoWrapper = wil::MakeOrThrow( + ComPtr supportContext = wil::MakeOrThrow( &protoHelper, - executionHandle, - true, - &outputShapes, - &defaultAttributesCapture, - graphNodeCreateInfo, - constantCpuInputCapture, - constantInputGetter); - - Microsoft::WRL::ComPtr kernel; - THROW_IF_FAILED(kernelFactoryCapture->CreateKernel(kernelInfoWrapper.Get(), kernel.GetAddressOf())); - kernelInfoWrapper->Close(); - }; + &defaultAttributesCapture); - if (requiredInputCountForGraph) - { - registration->requiredInputCount = *requiredInputCountForGraph; + BOOL bSupported = FALSE; + THROW_IF_FAILED(supportQueryCapture->QuerySupport(supportContext.Get(), &bSupported)); + return !!bSupported; + }; } - registration->requiresFloatFormatsExceptConstInputs = requiresFloatFormatsForGraph; - registration->requiredConstantCpuInputs = constantCpuInputCapture; - - onnxruntime::KernelDef* kernelDef = create_info.kernel_def.get(); THROW_IF_NOT_OK(m_kernelRegistry->RegisterCustomKernel(create_info)); - (*m_graphNodeFactoryMap)[kernelDef] = registration; + (*m_internalRegInfoMap)[kernelDef] = regInfo; } else { - // For backward compatibility, this does not propagate errors + // Currently unsupported for external operators + if (canAliasFirstInput || supportsGraph || requiredInputCountForGraph || + requiresFloatFormatsForGraph || requiredConstantCpuInputs) + { + THROW_HR(E_INVALIDARG); + } + + // + // For backward compatibility, this does not propagate errors for external operators 
m_kernelRegistry->RegisterCustomKernel(create_info); } diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/AbiCustomRegistry.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/AbiCustomRegistry.h index d3b90538308c1..dee783cd55a7d 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/AbiCustomRegistry.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/AbiCustomRegistry.h @@ -36,6 +36,7 @@ class AbiCustomRegistry : public WRL::Base GetGraphNodeFactoryMap() const + std::shared_ptr GetInternalRegInfoMap() const { - return m_graphNodeFactoryMap; + return m_internalRegInfoMap; } private: @@ -104,8 +105,9 @@ class AbiCustomRegistry : public WRL::Base, std::shared_ptr> m_customRegistryOpsetVerMap; - // Map between Lotus KernelDefs and graph node factories used for fusing nodes for graph compilation - mutable std::shared_ptr m_graphNodeFactoryMap; + // Map between Lotus KernelDefs and extended data used during partitioning + mutable std::shared_ptr m_internalRegInfoMap; + }; } // namespace winrt::Windows::AI::MachineLearning::implementation diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp index 5563d9574653f..9352bfdae724d 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp @@ -45,7 +45,7 @@ namespace Dml static void CreateDmlKernelRegistry( _Outptr_ std::shared_ptr* registry, - _Outptr_ std::shared_ptr* graphNodeFactoryMap) + _Outptr_ std::shared_ptr* internalRegInfoMap) { ComPtr abiRegistry = wil::MakeOrThrow(); Dml::RegisterDmlOperators(abiRegistry.Get()); @@ -54,7 +54,7 @@ namespace Dml auto customRegistry = *abiRegistry->GetRegistries().begin(); *registry = customRegistry->GetKernelRegistry(); - *graphNodeFactoryMap = abiRegistry->GetGraphNodeFactoryMap(); + *internalRegInfoMap = abiRegistry->GetInternalRegInfoMap(); } ExecutionProvider::ExecutionProvider( @@ -176,7 +176,7 @@ namespace Dml m_cpuInputAllocator = std::make_shared(OrtMemType::OrtMemTypeCPUInput); m_cpuOutputAllocator = std::make_shared(OrtMemType::OrtMemTypeCPUOutput); - CreateDmlKernelRegistry(&m_kernelRegistry, &m_graphNodeFactoryMap); + CreateDmlKernelRegistry(&m_kernelRegistry, &m_internalRegInfoMap); } HRESULT __stdcall ExecutionProviderImpl::GetD3DDevice(_COM_Outptr_ ID3D12Device** d3dDevice) const noexcept @@ -493,7 +493,14 @@ namespace Dml { std::string partitionKernelPrefix = std::to_string(m_partitionKernelPrefixVal++) + "_"; uint32_t deviceDataTypeMask = GetSuppportedDeviceDataTypeMask(); - return PartitionGraph(graph, *m_graphNodeFactoryMap, registries, deviceDataTypeMask, m_kernelRegistry.get(), partitionKernelPrefix); + + return PartitionGraph(graph, + *m_internalRegInfoMap, + registries, + deviceDataTypeMask, + m_kernelRegistry.get(), + partitionKernelPrefix + ); } Status ExecutionProviderImpl::CopyTensor(const onnxruntime::Tensor& src, onnxruntime::Tensor& dst) const diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h index 39e6d2ba01329..da8b405221598 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h @@ -172,7 +172,7 @@ namespace Dml std::shared_ptr m_cpuInputAllocator; std::shared_ptr 
m_cpuOutputAllocator; std::shared_ptr m_kernelRegistry; - std::shared_ptr m_graphNodeFactoryMap; + std::shared_ptr m_internalRegInfoMap; mutable uint64_t m_partitionKernelPrefixVal = 0; bool m_closed = false; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp index 7a924e0c6ce81..8b6b42c63a70b 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp @@ -127,7 +127,7 @@ namespace Dml::GraphDescBuilder const onnxruntime::Node& node = *graph.GetNode(sortedNodeIndex); const GraphNodeProperties& graphNodeProps = graphNodePropertyMap.find(GetUniqueNodeName(node))->second; - const auto& requiredConstantCpuInputs = graphNodeProps.graphNodeFactoryRegistration->requiredConstantCpuInputs; + const auto& requiredConstantCpuInputs = graphNodeProps.internalRegInfo->requiredConstantCpuInputs; MLOperatorTensorGetter constantCpuNodeInputGetter = [&node, &constantCpuGraphInputGetter, &requiredConstantCpuInputs](uint32_t inputIndex) { @@ -144,7 +144,7 @@ namespace Dml::GraphDescBuilder }; DmlGraphNodeCreateInfo graphNodeInfo; - graphNodeProps.graphNodeFactoryRegistration->factory( + graphNodeProps.internalRegInfo->graphNodeFactoryRegistration->factory( node, constantCpuNodeInputGetter, executionHandle, diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.h index 9d14a07545646..68fc7cc9f513d 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.h @@ -9,8 +9,8 @@ namespace Dml { struct GraphNodeProperties { - std::shared_ptr - graphNodeFactoryRegistration; + std::shared_ptr + internalRegInfo; // These are currently passed from the partitioning step since the only DML operators current // supporting graph nodes don't customize the order of edges or shapes, other than coercing diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphPartitioner.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphPartitioner.cpp index 6d1c16420aca9..12973981d75af 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphPartitioner.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphPartitioner.cpp @@ -172,14 +172,14 @@ namespace Dml return true; } - bool NodeTensorTypesSupportedInGraph(const onnxruntime::Node& node, const GraphNodeFactoryRegistration& registration) + bool NodeTensorTypesSupportedInGraph(const onnxruntime::Node& node, const InternalRegistrationInfo& registration) { for (size_t i = 0; i < node.InputDefs().size(); ++i) { bool isConstantCpuInput = std::find(registration.requiredConstantCpuInputs.begin(), registration.requiredConstantCpuInputs.end(), i) != registration.requiredConstantCpuInputs.end(); - if (!isConstantCpuInput && !NodeArgSupportedInGraph(node.InputDefs()[i], registration.requiresFloatFormatsExceptConstInputs)) + if (!isConstantCpuInput && !NodeArgSupportedInGraph(node.InputDefs()[i], registration.graphNodeFactoryRegistration->requiresFloatFormatsExceptConstInputs)) { return false; } @@ -187,7 +187,7 @@ namespace Dml for (auto arg : node.OutputDefs()) { - if (!NodeArgSupportedInGraph(arg, registration.requiresFloatFormatsExceptConstInputs)) + if (!NodeArgSupportedInGraph(arg, 
registration.graphNodeFactoryRegistration->requiresFloatFormatsExceptConstInputs)) { return false; } @@ -196,8 +196,31 @@ namespace Dml return true; } + bool TryGetTensorDataType( + const onnxruntime::NodeArg& nodeArg, + _Out_ MLOperatorTensorDataType* onnxElementType + ) + { + *onnxElementType = MLOperatorTensorDataType::Undefined; + + const ::onnx::TypeProto* typeProto = nodeArg.TypeAsProto(); + if (typeProto != nullptr && typeProto->has_tensor_type()) + { + const ::onnx::TypeProto_Tensor& tensorTypeProto = typeProto->tensor_type(); + if (tensorTypeProto.has_elem_type()) + { + *onnxElementType = static_cast(tensorTypeProto.elem_type()); + return true; + } + } + + return false; + } + bool DoesNodeContainSupportedDataTypes( const onnxruntime::Node& node, + const std::unordered_map& nodeNameToPartitionMap, + _In_opt_ const InternalRegistrationInfo* regInfo, uint32_t supportedDeviceDataTypeMask // Each bit corresponds to each DML_TENSOR_DATA_TYPE. ) { @@ -209,33 +232,58 @@ namespace Dml { // Get the tensor element data type for this node, comparing against what the device actually supports. // Use the enumeration from the proto instead of nodeArg.Type() which returns a string. - - const ::onnx::TypeProto* typeProto = nodeArg.TypeAsProto(); - if (typeProto != nullptr && typeProto->has_tensor_type()) + MLOperatorTensorDataType onnxElementType; + if (TryGetTensorDataType(nodeArg, &onnxElementType)) { - const ::onnx::TypeProto_Tensor& tensorTypeProto = typeProto->tensor_type(); - if (tensorTypeProto.has_elem_type()) + DML_TENSOR_DATA_TYPE dmlElementType = GetDmlDataTypeFromMlDataTypeNoThrow(onnxElementType); + if (dmlElementType != DML_TENSOR_DATA_TYPE_UNKNOWN) { - MLOperatorTensorDataType onnxElementType = static_cast(tensorTypeProto.elem_type()); - DML_TENSOR_DATA_TYPE dmlElementType = GetDmlDataTypeFromMlDataTypeNoThrow(onnxElementType); - if (dmlElementType != DML_TENSOR_DATA_TYPE_UNKNOWN) + if (((1 << dmlElementType) & supportedDeviceDataTypeMask) == 0) { - if ((1 << dmlElementType) & supportedDeviceDataTypeMask) - { - // Leave nodeContainsSupportedDataTypes alone, since data type is supported. - return; - } + nodeContainsSupportedDataTypes = false; } } } - - // Else it's not supported (non-tensors, opaque data types, unsupported data types...). - nodeContainsSupportedDataTypes = false; }; // Check whether the node uses any data types which are unsupported by the device. node.ForEachDef(nodeCallback); + // DML kernels support int64 and uint64 are expected to not *introduce* values out of range, which allows + // the temporary trick using strides to emulate 64 bit tensors to work. If the source is a CPU operator, + // graph input or initializer, it's not safe to assume the input can be represented with 32 bits. + if (regInfo) + { + for (uint32_t i = 0; i < node.InputDefs().size(); ++i) + { + const auto* arg = node.InputDefs()[i]; + MLOperatorTensorDataType onnxElementType; + if (arg->Exists() && TryGetTensorDataType(*arg, &onnxElementType)) + { + if (((onnxElementType == MLOperatorTensorDataType::UInt64) || (onnxElementType == MLOperatorTensorDataType::Int64))) + { + // Look up the input partition. If it's a graph input or initializer it will be missing + // from the map. In this case or if the input comes from a CPU partition, it might be + // out of range. 
+ const std::string& argName = arg->Name(); + auto partitionIter = nodeNameToPartitionMap.find(argName); + if (partitionIter == nodeNameToPartitionMap.end() || !partitionIter->second->IsDmlPartition()) + { + // Check if the operator handles the input on the CPU as a constant input + bool isConstantCpuInput = std::find(regInfo->requiredConstantCpuInputs.begin(), regInfo->requiredConstantCpuInputs.end(), i) != + regInfo->requiredConstantCpuInputs.end(); + + if (!isConstantCpuInput) + { + nodeContainsSupportedDataTypes = false; + break; + } + } + } + } + } + } + return nodeContainsSupportedDataTypes; } @@ -245,7 +293,8 @@ namespace Dml const onnxruntime::Node& node, const std::vector& dmlRegistries, uint32_t supportedDeviceDataTypeMask, // Each bit corresponds to each DML_TENSOR_DATA_TYPE. - const GraphNodeFactoryMap& graphNodeFactoryMap, + const InternalRegistrationInfoMap& internalRegInfoMap, + const std::unordered_map& nodeNameToPartitionMap, _Inout_ std::unordered_map& dmlNodePropertyMap, _Inout_ std::unordered_set& requiredInitializerMap, _Out_ bool* isDmlNode, @@ -260,11 +309,26 @@ namespace Dml for (auto registry : dmlRegistries) { const onnxruntime::KernelCreateInfo* createInfo = registry->TryFindKernel(node, onnxruntime::kDmlExecutionProvider); + if (!createInfo) + { + continue; + } + + auto regInfoIter = internalRegInfoMap.find(createInfo->kernel_def.get()); + std::shared_ptr internalRegInfo; + if (regInfoIter != internalRegInfoMap.end()) + { + internalRegInfo = regInfoIter->second; + if (internalRegInfo->supportQuery && !internalRegInfo->supportQuery(node)) + { + continue; + } + } // Check whether the node uses any data types which are unsupported by the device. - bool nodeContainsSupportedDataTypes = DoesNodeContainSupportedDataTypes(node, supportedDeviceDataTypeMask); + bool nodeContainsSupportedDataTypes = DoesNodeContainSupportedDataTypes(node, nodeNameToPartitionMap, internalRegInfo.get(), supportedDeviceDataTypeMask); - if (createInfo && nodeContainsSupportedDataTypes) + if (nodeContainsSupportedDataTypes) { *isDmlNode = true; @@ -274,12 +338,11 @@ namespace Dml // Ensure that shape information is known statically for the inputs and outputs of the node, // which is required for MLGraph compilation. 
- auto graphNodeFactorMapIter = graphNodeFactoryMap.find(createInfo->kernel_def.get()); - if (graphNodeFactorMapIter != graphNodeFactoryMap.end() && - NodeTensorTypesSupportedInGraph(node, *graphNodeFactorMapIter->second)) + if (internalRegInfo && internalRegInfo->graphNodeFactoryRegistration && + NodeTensorTypesSupportedInGraph(node, *internalRegInfo)) { bool requiredCpuInputsConstant = true; - for (uint32_t inputIndex : graphNodeFactorMapIter->second->requiredConstantCpuInputs) + for (uint32_t inputIndex : internalRegInfo->requiredConstantCpuInputs) { if (inputIndex >= node.InputDefs().size() || !node.InputDefs()[inputIndex]->Exists()) { @@ -298,7 +361,7 @@ namespace Dml requiredInitializerMap.insert(inputName); } - std::optional requiredInputCount = graphNodeFactorMapIter->second->requiredInputCount; + std::optional requiredInputCount = internalRegInfo->graphNodeFactoryRegistration->requiredInputCount; if (requiredCpuInputsConstant && TryGetStaticInputShapes( node, graphNodeProperty.first->second.inputShapes) && !ContainsEmptyDimensions(graphNodeProperty.first->second.inputShapes) && @@ -307,7 +370,7 @@ namespace Dml (requiredInputCount == std::nullopt || *requiredInputCount == node.InputDefs().size())) { *isDmlGraphNode = true; - graphNodeProperty.first->second.graphNodeFactoryRegistration = graphNodeFactorMapIter->second; + graphNodeProperty.first->second.internalRegInfo = internalRegInfo; } } @@ -550,7 +613,7 @@ namespace Dml std::vector> BuildPartitions( const onnxruntime::GraphViewer& graph, - const GraphNodeFactoryMap& graphNodeFactoryMap, + const InternalRegistrationInfoMap& internalRegInfoMap, const std::vector& registries, uint32_t supportedDeviceDataTypeMask, // Each bit corresponds to each DML_TENSOR_DATA_TYPE. std::unordered_map& graphNodePropertyMap, @@ -610,7 +673,8 @@ namespace Dml node, registries, supportedDeviceDataTypeMask, - graphNodeFactoryMap, + internalRegInfoMap, + nodeNameToPartitionMap, graphNodePropertyMap, requiredInitializerMap, /*out*/ &isDmlNode, @@ -726,7 +790,7 @@ namespace Dml std::vector> PartitionGraph( const onnxruntime::GraphViewer& graph, - const GraphNodeFactoryMap& graphNodeFactoryMap, + const InternalRegistrationInfoMap& internalRegInfoMap, const std::vector& registries, uint32_t supportedDeviceDataTypeMask, // Each bit corresponds to each DML_TENSOR_DATA_TYPE. onnxruntime::KernelRegistry* registryForPartitionKernels, @@ -741,7 +805,7 @@ namespace Dml std::unordered_map graphNodePropertyMap; std::vector> partitions = BuildPartitions( graph, - graphNodeFactoryMap, + internalRegInfoMap, registries, supportedDeviceDataTypeMask, graphNodePropertyMap, diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphPartitioner.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphPartitioner.h index 4cc77ab426414..2227a2e67bf0a 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphPartitioner.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphPartitioner.h @@ -43,7 +43,7 @@ namespace Dml std::vector> BuildPartitions( const onnxruntime::GraphViewer& graph, - const winrt::Windows::AI::MachineLearning::implementation::GraphNodeFactoryMap& graphNodeFactoryMap, + const winrt::Windows::AI::MachineLearning::implementation::InternalRegistrationInfoMap& internalRegInfoMap, const std::vector& registries, uint32_t supportedDeviceDataTypeMask, // Each bit corresponds to each DML_TENSOR_DATA_TYPE. 
std::unordered_map& graphNodePropertyMap, @@ -53,7 +53,7 @@ namespace Dml std::vector> PartitionGraph( const onnxruntime::GraphViewer& graph, - const winrt::Windows::AI::MachineLearning::implementation::GraphNodeFactoryMap& graphNodeFactoryMap, + const winrt::Windows::AI::MachineLearning::implementation::InternalRegistrationInfoMap& internalRegInfoMap, const std::vector& registries, uint32_t supportedDeviceDataTypeMask, // Each bit corresponds to each DML_TENSOR_DATA_TYPE. onnxruntime::KernelRegistry* registryForPartitionKernels, diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.cpp index 59317beb3cedb..b959e0c930755 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.cpp @@ -1794,6 +1794,13 @@ HRESULT STDMETHODCALLTYPE MLKernelInferenceContext::SetOutputTensorShape( } CATCH_RETURN(); +MLSupportQueryContext::MLSupportQueryContext( + onnxruntime::OpNodeProtoHelper* info, + const AttributeMap* defaultAttributes) : + OpNodeInfoWrapper(info, nullptr, defaultAttributes, gsl::span(), MLOperatorTensorGetter()) +{ +} + bool TryGetStaticShapeIfTensor( const onnx::TypeProto* inputProto, std::vector& shapeDims) { diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.h index c23eec22fe824..d6fabb2f7287b 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.h @@ -616,6 +616,21 @@ void InferAndVerifyOutputSizes( const EdgeShapes* inputShapes, EdgeShapes& outputShapes); +class MLSupportQueryContext final : public OpNodeInfoWrapper< + onnxruntime::ProtoHelperNodeContext, + WRL::Base>, + onnxruntime::null_type> +{ + public: + MLSupportQueryContext() = delete; + + MLSupportQueryContext( + onnxruntime::OpNodeProtoHelper* info, + const AttributeMap* defaultAttributes); + + // TODO - ... +}; + onnxruntime::MLDataType ToTensorDataType(::MLOperatorTensorDataType type); std::string ToTypeString(MLOperatorEdgeDescription desc); onnx::AttributeProto_AttributeType ToProto(MLOperatorAttributeType type); diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorConvolution.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorConvolution.cpp index 990e0981c8172..6a24b9c704301 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorConvolution.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorConvolution.cpp @@ -14,14 +14,17 @@ class DmlOperatorConvolution : public DmlOperator, public ConvolutionHelperBase DmlOperatorConvolution( const MLOperatorKernelCreationContext& kernelInfo, DML_CONVOLUTION_MODE mode, - DML_CONVOLUTION_DIRECTION direction + DML_CONVOLUTION_DIRECTION direction, + bool hasDynamicPads ) : DmlOperator(kernelInfo), - ConvolutionHelperBase(kernelInfo, kernelInfo.GetTensorShapeDescription(), direction == DML_CONVOLUTION_DIRECTION_BACKWARD) + ConvolutionHelperBase(kernelInfo, kernelInfo.GetTensorShapeDescription(), direction == DML_CONVOLUTION_DIRECTION_BACKWARD, hasDynamicPads) { - ML_CHECK_VALID_ARGUMENT(kernelInfo.GetInputCount() >= 2); + uint32_t biasIndex = hasDynamicPads ? 
3 : 2; + bool hasBiasInput = kernelInfo.GetInputCount() > biasIndex; + + std::vector> kernelInputIndices = { 0, 1, biasIndex }; - std::vector> kernelInputIndices = {0, 1, 2}; DmlOperator::Initialize(kernelInfo, kernelInputIndices); // Vibranium DirectML is limited to handle only 2D and 3D convolution (4D and 5D tensors). So for 1D tensors, @@ -32,7 +35,7 @@ class DmlOperatorConvolution : public DmlOperator, public ConvolutionHelperBase m_inputTensorDescs[1] = CreateTensorDescFromInput(kernelInfo, 1, TensorAxis::DoNotCoerce, TensorAxis::NoPlacementAdjustment, NonspatialDimensionCount, std::nullopt); // Bias is optional so only adjust it if it exists. - if (kernelInfo.GetInputCount() > 2) + if (hasBiasInput) { uint32_t inputDimSize = kernelInfo.GetTensorShapeDescription().GetInputTensorDimensionCount(0); ML_CHECK_VALID_ARGUMENT( @@ -43,9 +46,9 @@ class DmlOperatorConvolution : public DmlOperator, public ConvolutionHelperBase // Resize the bias to be the same dimension as the input tensor. // The 1D tensor needs to be moved to the C channel. - m_inputTensorDescs[2] = CreateTensorDescFromInput( + m_inputTensorDescs[biasIndex] = CreateTensorDescFromInput( kernelInfo, - 2, + biasIndex, TensorAxis::DoNotCoerce, TensorAxis::C, TensorAxis::LeftAligned, @@ -73,7 +76,7 @@ class DmlOperatorConvolution : public DmlOperator, public ConvolutionHelperBase DML_CONVOLUTION_OPERATOR_DESC convDesc = {}; convDesc.InputTensor = &inputDescs[0]; convDesc.FilterTensor = &inputDescs[1]; - convDesc.BiasTensor = kernelInfo.GetInputCount() > 2 ? &inputDescs[2] : nullptr; + convDesc.BiasTensor = hasBiasInput ? &inputDescs[biasIndex] : nullptr; convDesc.OutputTensor = &outputDescs[0]; convDesc.Mode = mode; convDesc.Direction = direction; @@ -92,19 +95,20 @@ class DmlOperatorConvolution : public DmlOperator, public ConvolutionHelperBase }; // A specific type of operation for registration. -template +template class DmlOperatorConvolutionTemplate : public DmlOperatorConvolution { public: DmlOperatorConvolutionTemplate(const MLOperatorKernelCreationContext& kernelInfo) - : DmlOperatorConvolution(kernelInfo, Mode, Direction) + : DmlOperatorConvolution(kernelInfo, Mode, Direction, hasDynamicPads) { } }; -DML_OP_DEFINE_CREATION_FUNCTION(Conv, DmlOperatorConvolutionTemplate); -DML_OP_DEFINE_CREATION_FUNCTION(ConvTranspose, DmlOperatorConvolutionTemplate); -DML_OP_DEFINE_CREATION_FUNCTION(FusedConv, DmlOperatorConvolutionTemplate); -DML_OP_DEFINE_CREATION_FUNCTION(FusedConvTranspose, DmlOperatorConvolutionTemplate); +DML_OP_DEFINE_CREATION_FUNCTION(Conv, DmlOperatorConvolutionTemplate); +DML_OP_DEFINE_CREATION_FUNCTION(ConvTranspose, DmlOperatorConvolutionTemplate); +DML_OP_DEFINE_CREATION_FUNCTION(FusedConv, DmlOperatorConvolutionTemplate); +DML_OP_DEFINE_CREATION_FUNCTION(FusedConvTranspose, DmlOperatorConvolutionTemplate); +DML_OP_DEFINE_CREATION_FUNCTION(ConvTransposeWithDynamicPads, DmlOperatorConvolutionTemplate); } // namespace Dml diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorCopy.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorCopy.cpp index 0aeb4a2fc0959..eb985f6c404c9 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorCopy.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorCopy.cpp @@ -24,10 +24,7 @@ class DmlOperatorCopy : public DmlOperator // element counts are the same. 
All this operator does is copy the resource and // rearrange the dimensions, so we tell DML that the output dimensions are the // same as the input dimensions. - m_outputTensorDescs.front() = TensorDesc( - m_outputTensorDescs.front().GetDmlDataType(), - m_inputTensorDescs.front().GetSizes() - ); + m_outputTensorDescs.front() = m_inputTensorDescs.front(); ComPtr contextPrivate; THROW_IF_FAILED(kernelInfo.GetInterface()->QueryInterface(contextPrivate.GetAddressOf())); diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorPooling.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorPooling.cpp index 9bd86a8d43b94..7f27291ede562 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorPooling.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorPooling.cpp @@ -29,6 +29,19 @@ class DmlOperatorPooling : public DmlOperator, public PoolingHelperBase assert(m_kernel.spatialDimensionCount <= ARRAYSIZE(m_kernel.windowSize)); + // The below attributes are temporarily not supported: + int ceilMode = kernelInfo.GetOptionalAttribute(AttrName::CeilMode, 0); + THROW_HR_IF(E_NOTIMPL, ceilMode != 0); + + int storageOrder = kernelInfo.GetOptionalAttribute(AttrName::StorageOrder, 0); + THROW_HR_IF(E_NOTIMPL, storageOrder != 0); + + auto dilations = kernelInfo.GetOptionalAttributeVectorInt32(AttrName::Dilations); + for (int dilation : dilations) + { + THROW_HR_IF(E_NOTIMPL, dilation != 1); + } + // DML requires that DimensionCount be equal to Input.DimCount - 2 for Pooling uint32_t expectedSpatialDimCount = m_inputTensorDescs[0].GetDimensionCount() - 2; if (m_kernel.spatialDimensionCount < expectedSpatialDimCount) @@ -121,6 +134,37 @@ class DmlOperatorPoolingTemplate : public DmlOperatorPooling } }; +void QueryMaxPool(IMLOperatorSupportQueryContextPrivate* context, bool *isSupported) +{ + *isSupported = false; + + MLOperatorAttributes attributes(context); + + // The below attributes are temporarily not supported: + int ceilMode = attributes.GetOptionalAttribute(AttrName::CeilMode, 0); + if (ceilMode != 0) + { + return; + } + + int storageOrder = attributes.GetOptionalAttribute(AttrName::StorageOrder, 0); + if (storageOrder != 0) + { + return; + } + + auto dilations = attributes.GetOptionalAttributeVectorInt32(AttrName::Dilations); + for (int dilation : dilations) + { + if (dilation != 1) + { + return; + } + } + + *isSupported = true; +} + DML_OP_DEFINE_CREATION_FUNCTION(AveragePool, DmlOperatorPoolingTemplate); DML_OP_DEFINE_CREATION_FUNCTION(GlobalAveragePool, DmlOperatorPoolingTemplate); DML_OP_DEFINE_CREATION_FUNCTION(MaxPool, DmlOperatorPoolingTemplate); diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorSlice.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorSlice.cpp index 7ae1995e74cbe..e167a89f0606e 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorSlice.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorSlice.cpp @@ -6,16 +6,22 @@ namespace Dml { -class DmlOperatorSlice : public DmlOperator, public SliceHelper +class DmlOperatorSlice : public DmlOperator, public SliceHelperBase { public: - DmlOperatorSlice(const MLOperatorKernelCreationContext& kernelInfo) + DmlOperatorSlice(const MLOperatorKernelCreationContext& kernelInfo, uint32_t opsetVersion) : DmlOperator(kernelInfo), - SliceHelper(kernelInfo, 
kernelInfo.GetTensorShapeDescription()) + SliceHelperBase(kernelInfo, kernelInfo.GetTensorShapeDescription(), opsetVersion) { - ML_CHECK_VALID_ARGUMENT(kernelInfo.GetInputCount() == 1); + uint32_t minInputCount = (opsetVersion < 10) ? 1 : 3; + ML_CHECK_VALID_ARGUMENT(kernelInfo.GetInputCount() >= minInputCount); ML_CHECK_VALID_ARGUMENT(kernelInfo.GetOutputCount() == 1); - DmlOperator::Initialize(kernelInfo); + + // TODO (23108599): Slice V10 introduces an optional "Steps" input which the kernel does not yet support. + THROW_HR_IF(E_NOTIMPL, kernelInfo.GetInputCount() > 4); + + std::vector> kernelInputIndices = { 0 }; + DmlOperator::Initialize(kernelInfo, kernelInputIndices); assert(m_inputTensorDescs[0].GetDimensionCount() >= gsl::narrow_cast(m_offsets.size())); assert(m_inputTensorDescs[0].GetDimensionCount() >= gsl::narrow_cast(m_sizes.size())); @@ -54,6 +60,22 @@ class DmlOperatorSlice : public DmlOperator, public SliceHelper } }; -DML_OP_DEFINE_CREATION_FUNCTION(Slice, DmlOperatorSlice); +// A specific type of operation for registration. +template +class DmlOperatorSliceTemplate : public DmlOperatorSlice +{ +public: + DmlOperatorSliceTemplate(const MLOperatorKernelCreationContext& kernelInfo) + : DmlOperatorSlice(kernelInfo, opsetVersion) + { + } +}; + +void QuerySlice(IMLOperatorSupportQueryContextPrivate* context, bool *isSupported) +{ + *isSupported = (context->GetInputCount() <= 4); +} +DML_OP_DEFINE_CREATION_FUNCTION(Slice7, DmlOperatorSliceTemplate<7>); +DML_OP_DEFINE_CREATION_FUNCTION(Slice10, DmlOperatorSliceTemplate<10>); } // namespace Dml diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp index b35c570bf8c53..c0ade42ccd5a4 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp @@ -70,12 +70,14 @@ struct OperatorRegistrationInformation // can't be represented as nodes in an optimized graph yet. 
std::optional requiredInputCountForDmlGraphSupport; + MLOperatorSupportQueryFunction supportQueryFunction; }; DML_OP_EXTERN_CREATION_FUNCTION(Copy); DML_OP_EXTERN_CREATION_FUNCTION(FC); DML_OP_EXTERN_CREATION_FUNCTION(Conv); DML_OP_EXTERN_CREATION_FUNCTION(ConvTranspose); +DML_OP_EXTERN_CREATION_FUNCTION(ConvTransposeWithDynamicPads); DML_OP_EXTERN_CREATION_FUNCTION(AveragePool); DML_OP_EXTERN_CREATION_FUNCTION(GlobalAveragePool); DML_OP_EXTERN_CREATION_FUNCTION(MaxPool); @@ -97,7 +99,8 @@ DML_OP_EXTERN_CREATION_FUNCTION(Split); DML_OP_EXTERN_CREATION_FUNCTION(Transpose); DML_OP_EXTERN_CREATION_FUNCTION(Tile); DML_OP_EXTERN_CREATION_FUNCTION(Concat); -DML_OP_EXTERN_CREATION_FUNCTION(Slice); +DML_OP_EXTERN_CREATION_FUNCTION(Slice7); +DML_OP_EXTERN_CREATION_FUNCTION(Slice10); DML_OP_EXTERN_CREATION_FUNCTION(Pad); DML_OP_EXTERN_CREATION_FUNCTION(SpaceToDepth); DML_OP_EXTERN_CREATION_FUNCTION(DepthToSpace); @@ -201,6 +204,9 @@ DML_OP_EXTERN_CREATION_FUNCTION(Scatter); DML_OP_EXTERN_CREATION_FUNCTION(Resize); DML_OP_EXTERN_CREATION_FUNCTION(ConstantOfShape); +DML_OP_EXTERN_QUERY_FUNCTION(MaxPool); +DML_OP_EXTERN_QUERY_FUNCTION(Slice); + const static char* const typeNameListDefault[1] = {"T"}; const static char* const typeNameListTopK[2] = { "T", "I" }; const static char* const typeNameListLogicalComparison[2] = { "T", "T1" }; @@ -220,7 +226,7 @@ const static SupportedTensorDataTypes supportedTypeListAllScalars[1] = { Support const static SupportedTensorDataTypes supportedTypeListBool[1] = {SupportedTensorDataTypes::Bool}; const static SupportedTensorDataTypes supportedTypeListTopK[2] = {SupportedTensorDataTypes::Float16to32, SupportedTensorDataTypes::Int64}; const static SupportedTensorDataTypes supportedTypeListIndices[1] = { SupportedTensorDataTypes::Int32|SupportedTensorDataTypes::Int64 }; -const static SupportedTensorDataTypes supportedTypeListCast[2] = { SupportedTensorDataTypes::Scalars8to32, SupportedTensorDataTypes::Scalars8to32 }; +const static SupportedTensorDataTypes supportedTypeListCast[2] = { SupportedTensorDataTypes::AllScalars, SupportedTensorDataTypes::Scalars8to32 }; const static SupportedTensorDataTypes supportedTypeListScatterGather[2] = { SupportedTensorDataTypes::Float16to32, SupportedTensorDataTypes::Int32 | SupportedTensorDataTypes::Int64 }; const static SupportedTensorDataTypes supportedTypeListQuantize[2] = { SupportedTensorDataTypes::Float32, SupportedTensorDataTypes::UInt8 }; const static SupportedTensorDataTypes supportedTypeListIsNan[2] = { SupportedTensorDataTypes::Float16to32, SupportedTensorDataTypes::UInt8 }; @@ -233,6 +239,10 @@ const static SupportedTensorDataTypes supportedTypeListLogicalComparison[2] = /* #define REG_INFO(version, operatorName, ...) \ #operatorName, OnnxOperatorSet##version::sc_sinceVer_##operatorName, onnxruntime::kOnnxDomain, Create##operatorName, ShapeInferenceFunction, false, false, ##__VA_ARGS__, +// Versioned operator +#define REG_INFO_VER(version, operatorName, ...) \ + #operatorName, OnnxOperatorSet##version::sc_sinceVer_##operatorName, onnxruntime::kOnnxDomain, Create##operatorName##version, ShapeInferenceFunction, false, false, ##__VA_ARGS__, + // Identity operators use Copy, alias their first input, and require floating point formats // for usage in the graph, besides constant inputs. 
This is because they currently use // element-wise identity operators in the graph for striding support, but issue actual copies @@ -242,12 +252,17 @@ const static SupportedTensorDataTypes supportedTypeListLogicalComparison[2] = /* // MS-domain operators #define REG_INFO_MS(version, operatorName, ...) \ + #operatorName, MsftOperatorSet##version::sc_sinceVer_##operatorName, onnxruntime::kMSDomain, Create##operatorName, ShapeInferenceFunction, false, false, ##__VA_ARGS__, + +// MS-domain operators +#define REG_INFO_MSDML(version, operatorName, ...) \ #operatorName, MsftOperatorSet##version::sc_sinceVer_##operatorName, onnxruntime::kMSDmlDomain, Create##operatorName, ShapeInferenceFunction, false, false, ##__VA_ARGS__, const static OperatorRegistrationInformation operatorRegistrationInformationTable[] = { /// Domain/Type, Ver, Name, TypeNames, Types, Graph Support, Required const CPU inputs, -/// Input count required for graph support +/// Input count required for graph support, +/// Support query function // Deep Learning Standard Layers {REG_INFO( 7, Conv, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, @@ -255,7 +270,9 @@ const static OperatorRegistrationInformation operatorRegistrationInformationTabl {REG_INFO( 7, AveragePool, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, {REG_INFO( 7, GlobalAveragePool, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, {REG_INFO( 7, MaxPool, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, - {REG_INFO( 8, MaxPool, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, + {REG_INFO( 8, MaxPool, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported, {}, std::nullopt, QueryMaxPool)}, + {REG_INFO( 10, MaxPool, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported, {}, std::nullopt, QueryMaxPool)}, + {REG_INFO( 7, GlobalMaxPool, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, {REG_INFO( 7, LpPool, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, {REG_INFO( 7, GlobalLpPool, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, @@ -269,12 +286,14 @@ const static OperatorRegistrationInformation operatorRegistrationInformationTabl {REG_INFO( 7, RNN, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::NotSupported)}, {REG_INFO( 7, GRU, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::NotSupported)}, {REG_INFO( 7, LSTM, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::NotSupported)}, + {REG_INFO_MS( 1, ConvTransposeWithDynamicPads, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported, {2})}, // Data Reorganization Layers {REG_INFO( 7, Split, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, {REG_INFO( 7, Transpose, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, {REG_INFO( 7, Concat, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, - {REG_INFO( 7, Slice, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, + {REG_INFO_VER( 7, Slice, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, + {REG_INFO_VER( 10, Slice, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported, {1, 2, 3}, std::nullopt, QuerySlice)}, {REG_INFO( 7, Pad, typeNameListDefault, 
supportedTypeListFloat16to32, DmGraphSupport::Supported)}, {REG_INFO( 7, SpaceToDepth, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, {REG_INFO( 7, DepthToSpace, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, @@ -400,19 +419,19 @@ const static OperatorRegistrationInformation operatorRegistrationInformationTabl {REG_INFO( 9, OneHot, typeNameListOneHot, supportedTypeListOneHot, DmGraphSupport::Supported, {1})}, // Fused operators - {REG_INFO_MS( 1, FusedConv, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, - {REG_INFO_MS( 1, FusedConvTranspose, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, - {REG_INFO_MS( 1, FusedInstanceNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, - {REG_INFO_MS( 1, FusedBatchNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, - {REG_INFO_MS( 1, FusedMeanVarianceNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, - {REG_INFO_MS( 1, FusedGemm, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, - {REG_INFO_MS( 1, FusedMatMul, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, - {REG_INFO_MS( 1, FusedAdd, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, - {REG_INFO_MS( 1, FusedSum, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported, {}, 2)}, + {REG_INFO_MSDML(1, FusedConv, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, + {REG_INFO_MSDML(1, FusedConvTranspose, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, + {REG_INFO_MSDML(1, FusedInstanceNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, + {REG_INFO_MSDML(1, FusedBatchNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, + {REG_INFO_MSDML(1, FusedMeanVarianceNormalization, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, + {REG_INFO_MSDML(1, FusedGemm, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, + {REG_INFO_MSDML(1, FusedMatMul, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, + {REG_INFO_MSDML(1, FusedAdd, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported)}, + {REG_INFO_MSDML(1, FusedSum, typeNameListDefault, supportedTypeListFloat16to32, DmGraphSupport::Supported, {}, 2)}, // TODO: DwayneR implement MaxUnpool https://dev.azure.com/microsoft/OS/_workitems/edit/21267466 }; - + template MLOperatorEdgeDescription EdgeDesc() { @@ -497,10 +516,17 @@ void RegisterDmlOperators(IMLOperatorRegistry* registry) shapeInferrer = wil::MakeOrThrow(information.shapeInferenceFunction); } + ComPtr supportQuery; + if (information.supportQueryFunction) + { + supportQuery = wil::MakeOrThrow(information.supportQueryFunction); + } + THROW_IF_FAILED(registryPrivate->RegisterOperatorKernel( &desc, factory.Get(), shapeInferrer.Get(), + supportQuery.Get(), true, // isInternalOperator information.canAliasFirstInput, // alias kernelSupportsGraph, // supportsGraph diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.h index 12ca021b60fc2..186608d78b586 100644 --- 
a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.h @@ -9,6 +9,7 @@ class MLOperatorKernelCreationContext; // Forward declares an external creation function. #define DML_OP_EXTERN_CREATION_FUNCTION(operatorName) extern void Create##operatorName(IMLOperatorKernelCreationContext* kernelInfo, IMLOperatorKernel** opKernel) +#define DML_OP_EXTERN_QUERY_FUNCTION(operatorName) extern void Query##operatorName(IMLOperatorSupportQueryContextPrivate* context, bool *isSupported); // Declares a callback creation function of the given operator class. // This does not register it, just declares it for usage by registration later. diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/Attributes.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/Attributes.h index 20f4d69cbdde0..dc7def086dad2 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/Attributes.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/Attributes.h @@ -19,6 +19,7 @@ namespace AttrName static constexpr const char* BlockSize = "blocksize"; static constexpr const char* Border = "border"; static constexpr const char* Broadcast = "broadcast"; + static constexpr const char* CeilMode = "ceil_mode"; static constexpr const char* Clip = "clip"; static constexpr const char* CountIncludePad = "count_include_pad"; static constexpr const char* Dilations = "dilations"; @@ -60,6 +61,7 @@ namespace AttrName static constexpr const char* Split = "split"; static constexpr const char* Starts = "starts"; static constexpr const char* Steepness = "steepness"; + static constexpr const char* StorageOrder = "storage_order"; static constexpr const char* Strides = "strides"; static constexpr const char* Tiles = "tiles"; static constexpr const char* To = "to"; diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorHelper.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorHelper.h index 871df4070987f..b15955e0e533d 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorHelper.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorHelper.h @@ -735,6 +735,7 @@ class MLOperatorKernel : public Microsoft::WRL::RuntimeClass< using MLOperatorTypeInferenceFunction = void (CALLBACK*)(IMLOperatorTypeInferenceContext*); using MLOperatorShapeInferenceFunction = void (CALLBACK*)(IMLOperatorShapeInferenceContext*); using MLOperatorKernelCreateFn = void(*)(IMLOperatorKernelCreationContext*, IMLOperatorKernel**); +using MLOperatorSupportQueryFunction = void (CALLBACK*)(IMLOperatorSupportQueryContextPrivate*, bool*); class MLOperatorShapeInferrer : public Microsoft::WRL::RuntimeClass< Microsoft::WRL::RuntimeClassFlags, IMLOperatorShapeInferrer> @@ -755,6 +756,29 @@ class MLOperatorShapeInferrer : public Microsoft::WRL::RuntimeClass< MLOperatorShapeInferenceFunction m_shapeInferenceFn = nullptr; }; +class MLOperatorSupportQuery : public Microsoft::WRL::RuntimeClass< + Microsoft::WRL::RuntimeClassFlags, IMLOperatorSupportQueryPrivate> +{ +public: + MLOperatorSupportQuery(MLOperatorSupportQueryFunction queryFn) : + m_queryFn(queryFn) + {} + + HRESULT STDMETHODCALLTYPE QuerySupport( + IMLOperatorSupportQueryContextPrivate* context, + BOOL* isSupported) noexcept override try + { + bool fIsSupported = false; + m_queryFn(context, &fIsSupported); + *isSupported = fIsSupported ? 
TRUE : FALSE; + return S_OK; + } + CATCH_RETURN(); + +private: + MLOperatorSupportQueryFunction m_queryFn = nullptr; +}; + class MLOperatorTypeInferrer : public Microsoft::WRL::RuntimeClass< Microsoft::WRL::RuntimeClassFlags, IMLOperatorTypeInferrer> { diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorPrivate.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorPrivate.h index 6b0cbab3fb1fd..216dfdcb2ba3d 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorPrivate.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorPrivate.h @@ -46,6 +46,60 @@ IMLOperatorKernelCreationContextPrivate : public IMLOperatorKernelCreationContex ) const noexcept PURE; }; +//! \interface IMLOperatorAttributes1 +//! \brief Represents the values of an operator's attributes, as determined by a model using the operator. +//! This interface is called by implementations of custom operator kernels, and by implementations +//! of shape and type inferrers. +interface DECLSPEC_UUID("3a798815-dfe3-4bcd-b6a6-f70650d5f80b") DECLSPEC_NOVTABLE +IMLOperatorAttributes1 : public IMLOperatorAttributes +{ + //! Gets an interface pointer for the constant tensor. + //! Note the tensor is CPU side (IsCpuData is true). + STDMETHOD(GetTensorAttribute)( + _In_z_ const char* name, + _COM_Outptr_ IMLOperatorTensor** tensor + ) const noexcept PURE; +}; + +interface __declspec(uuid("897bb586-6cee-4106-8513-dda33151c109")) DECLSPEC_NOVTABLE +IMLOperatorSupportQueryContextPrivate : public IMLOperatorAttributes1 +{ + //! Gets the number of inputs to the operator. + STDMETHOD_(uint32_t, GetInputCount)() const noexcept PURE; + + //! Gets the number of outputs to the operator. + STDMETHOD_(uint32_t, GetOutputCount)() const noexcept PURE; + + //! Returns true if an input to the operator is valid. + //! This always returns true except for optional inputs and invalid indices. + STDMETHOD_(bool, IsInputValid)(uint32_t inputIndex) const noexcept PURE; + + //! Returns true if an output to the operator is valid. + //! This always returns true if within GetOutputCount except for optional outputs. + STDMETHOD_(bool, IsOutputValid)(uint32_t outputIndex) const noexcept PURE; + + //! Gets the description of the specified input edge of the operator. + STDMETHOD(GetInputEdgeDescription)( + uint32_t inputIndex, + _Out_ MLOperatorEdgeDescription* edgeDescription + ) const noexcept PURE; + + //! Gets the description of the specified output edge of the operator. + STDMETHOD(GetOutputEdgeDescription)( + uint32_t outputIndex, + _Out_ MLOperatorEdgeDescription* edgeDescription + ) const noexcept PURE; +}; + +interface __declspec(uuid("023954b3-aed2-4b03-b7c7-f0838053a9a1")) DECLSPEC_NOVTABLE +IMLOperatorSupportQueryPrivate : public IUnknown +{ + STDMETHOD(QuerySupport)( + IMLOperatorSupportQueryContextPrivate* context, + BOOL* isSupported + ) noexcept PURE; +}; + interface DECLSPEC_UUID("3de1dc1e-13e9-4099-ae88-7b4100083415") DECLSPEC_NOVTABLE IMLOperatorRegistryPrivate : public IUnknown { @@ -53,6 +107,7 @@ IMLOperatorRegistryPrivate : public IUnknown const MLOperatorKernelDescription* operatorKernel, IMLOperatorKernelFactory* operatorKernelFactory, _In_opt_ IMLOperatorShapeInferrer* shapeInferrer, + _In_opt_ IMLOperatorSupportQueryPrivate* supportQuery, bool isInternalOperator, bool canAliasFirstInput, bool supportsGraph, @@ -63,21 +118,6 @@ IMLOperatorRegistryPrivate : public IUnknown ) const noexcept PURE; }; -//! \interface IMLOperatorAttributes1 -//! 
\brief Represents the values of an operator's attributes, as determined by a model using the operator. -//! This interface is called by implementations of custom operator kernels, and by implementations -//! of shape and type inferrers. -interface DECLSPEC_UUID("3a798815-dfe3-4bcd-b6a6-f70650d5f80b") DECLSPEC_NOVTABLE -IMLOperatorAttributes1 : public IMLOperatorAttributes -{ - //! Gets an interface pointer for the constant tensor. - //! Note the tensor is CPU side (IsCpuData is true). - STDMETHOD(GetTensorAttribute)( - _In_z_ const char* name, - _COM_Outptr_ IMLOperatorTensor** tensor - ) const noexcept PURE; -}; - // Declare private enum MLOperatorAttributeType::Tensor. // // enum class MLOperatorAttributeType : uint32_t diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp index 99425a76f49f0..9f3a0ed7e69f4 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp @@ -430,58 +430,7 @@ int64_t ReadAsInt64(MLOperatorTensorDataType tensorDataType, const void* p) return edgeShapes; } - void SliceHelper::Initialize( - const MLOperatorAttributes& operatorAttributes, - gsl::span inputDimensions - ) - { - const uint32_t dimCount = gsl::narrow_cast(inputDimensions.size()); - - std::vector starts = operatorAttributes.GetOptionalAttributeVectorInt32(AttrName::Starts); - std::vector ends = operatorAttributes.GetOptionalAttributeVectorInt32(AttrName::Ends); - std::vector axes = operatorAttributes.GetOptionalAttributeVectorInt32(AttrName::Axes); - HandleNegativeAxes(/*inout*/ axes, dimCount); - - ML_CHECK_VALID_ARGUMENT(starts.size() == ends.size(), "'starts' must equal 'ends' in size."); - ML_CHECK_VALID_ARGUMENT(axes.empty() || starts.size() == axes.size(), "'axes' must equal 'starts' in size, or 'axes' must be empty."); - - m_outputDimensions.assign(inputDimensions.begin(), inputDimensions.end()); - m_offsets.resize(m_outputDimensions.size()); - m_sizes.resize(m_outputDimensions.size()); - m_strides.resize(m_outputDimensions.size(), 1); // Only a stride of 1 element is supported by ONNX 1.2. - - // Set initial defaults lest 'starts' and 'ends' arrays are shorter than the dimension count. - std::copy(inputDimensions.begin(), inputDimensions.begin() + m_outputDimensions.size(), m_sizes.begin()); - - // Clamp selected dimensions to given 'starts' and 'ends'. - for (int i = 0, ci = gsl::narrow_cast(starts.size()); i < ci; ++i) - { - int dimIndex = i; - if (!axes.empty()) - { - dimIndex = axes[i]; - } - ML_CHECK_VALID_ARGUMENT(dimIndex < inputDimensions.size(), "'axes' must be valid with within actual input dimensions."); - - // Positive values are offsets from 0. - // Negative values are offsets from the dimension's size. - int dim = gsl::narrow_cast(inputDimensions[dimIndex]); - int start = (starts[i] < 0) ? (starts[i] + dim) : starts[i]; - int end = (ends[i] < 0) ? (ends[i] + dim) : ends[i]; - - // Clamp the dimensions to the slice extents. - // Clamp negative numbers to 0, per case test_slice_start_out_of_bounds. 
- start = std::max(start, 0); - end = std::min(end, dim); - int size = std::max(end - start, 0); - - m_outputDimensions[dimIndex] = size; - m_offsets[dimIndex] = start; - m_sizes[dimIndex] = gsl::narrow_cast(size); - } - } - - std::vector SliceHelper::GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const + std::vector SliceHelperBase::GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const { return { m_outputDimensions }; } diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h index ad6da14f76469..03c003d2b6d0d 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h @@ -299,13 +299,13 @@ class ConvolutionHelperBase { public: enum FilterDims { K }; - enum InputTensor { X, Filter, Bias }; + enum InputTensor { X, Filter}; enum InputDims { N, C, H, W }; public: // Info_t is used to obtain attributes which will be used for calculating the output shape later. template - ConvolutionHelperBase(const Info_t& info, const Shape_t& shape, bool transpose) : + ConvolutionHelperBase(const Info_t& info, const Shape_t& shape, bool transpose, bool hasDynamicPads) : m_kernel(InitializeKernel(info, shape.GetInputTensorDimensionCount(0), shape.GetInputTensorShape(1))) { m_groupCount = info.GetOptionalAttribute(AttrName::Group, 1); @@ -316,7 +316,7 @@ class ConvolutionHelperBase } else { - InitializeKernelAndShapesTransposed(info, shape); + InitializeKernelAndShapesTransposed(info, shape, hasDynamicPads); } } @@ -347,7 +347,7 @@ class ConvolutionHelperBase template - void InitializeKernelAndShapesTransposed(const Info_t& info, const Shape_t& shapeInfo) + void InitializeKernelAndShapesTransposed(const Info_t& info, const Shape_t& shapeInfo, bool hasDynamicPads) { std::vector outputShape = info.GetOptionalAttributeVectorInt32(AttrName::OutputShape); if (!outputShape.empty()) @@ -363,7 +363,33 @@ class ConvolutionHelperBase ML_CHECK_VALID_ARGUMENT(inputDimensions.size() > NonspatialDimensionCount, "Input dimensions must be >= 3"); - ResolvingPadding(inputDimensions); + if (hasDynamicPads) + { + MLOperatorTensor padsTensor = info.GetConstantInputTensor(2); + const std::vector& padsTensorDimensions = padsTensor.GetShape(); + ML_CHECK_VALID_ARGUMENT(padsTensorDimensions.size() == 1, "Pads dimensions must equal 1"); + const size_t dimCount = padsTensorDimensions[0]; + ML_CHECK_VALID_ARGUMENT(dimCount == 2 * NchwSpatialDimensionCount, "Pads count must equal 4"); + const int64_t* padsData = padsTensor.GetData(); + + for (size_t i = 0; i < dimCount; ++i) + { + ML_CHECK_VALID_ARGUMENT(padsData[i] >= 0, "Padding values must be greater than or equal to 0"); + if (i < dimCount / 2) + { + m_kernel.startPadding[i] = gsl::narrow_cast(padsData[i]); + } + else + { + m_kernel.endPadding[i - dimCount/2] = gsl::narrow_cast(padsData[i]); + } + } + } + else + { + ResolvingPadding(inputDimensions); + } + m_outputShapes.resize(1); m_outputShapes[0] = InitializeKernelOutputDimsTranspose(inputDimensions, m_kernel); static_assert(C < NonspatialDimensionCount); @@ -419,14 +445,21 @@ class ConvHelper : public ConvolutionHelperBase { public: template - ConvHelper(const Info_t& info, const Shape_t& shape) : ConvolutionHelperBase(info, shape, false) {} + ConvHelper(const Info_t& info, const Shape_t& shape) : ConvolutionHelperBase(info, shape, false, false) {} }; class ConvTransposeHelper : public ConvolutionHelperBase { public: 
     template<typename Info_t, typename Shape_t>
-    ConvTransposeHelper(const Info_t& info, const Shape_t& shape) : ConvolutionHelperBase(info, shape, true) {}
+    ConvTransposeHelper(const Info_t& info, const Shape_t& shape) : ConvolutionHelperBase(info, shape, true, false) {}
+};
+
+class ConvTransposeWithDynamicPadsHelper : public ConvolutionHelperBase
+{
+public:
+    template<typename Info_t, typename Shape_t>
+    ConvTransposeWithDynamicPadsHelper(const Info_t& info, const Shape_t& shape) : ConvolutionHelperBase(info, shape, true, true) {}
 };
 
 class GemmHelper
@@ -501,20 +534,138 @@ class SplitHelper
     std::vector<int> m_split;
 };
 
-class SliceHelper
+class SliceHelperBase
 {
 public:
+    template<typename Info_t, typename Index_t>
+    void ReadIndexTensors(
+        const Info_t& operatorInfo,
+        std::vector<int32_t>& starts,
+        std::vector<int32_t>& ends,
+        std::vector<int32_t>& axes,
+        std::vector<int32_t>& steps
+        )
+    {
+        // Get starts, ends, optional axes and optional steps from constant inputs.
+        MLOperatorTensor startsTensor = operatorInfo.GetConstantInputTensor(1);
+        const std::vector<uint32_t>& startsTensorDimensions = startsTensor.GetShape();
+        size_t dimCount = startsTensorDimensions[0];
+        const Index_t* startsData = startsTensor.GetData<Index_t>();
+        for (size_t i = 0; i < dimCount; ++i)
+        {
+            starts.push_back(gsl::narrow_cast<int32_t>(startsData[i]));
+        }
+
+        MLOperatorTensor endsTensor = operatorInfo.GetConstantInputTensor(2);
+        const std::vector<uint32_t>& endsTensorDimensions = endsTensor.GetShape();
+        dimCount = endsTensorDimensions[0];
+        const Index_t* endsData = endsTensor.GetData<Index_t>();
+        for (size_t i = 0; i < dimCount; ++i)
+        {
+            ends.push_back(gsl::narrow_cast<int32_t>(endsData[i]));
+        }
+
+        uint32_t inputCount = operatorInfo.GetInputCount();
+        if (operatorInfo.GetInputCount() > 3)
+        {
+            MLOperatorTensor axesTensor = operatorInfo.GetConstantInputTensor(3);
+            const std::vector<uint32_t>& axesTensorDimensions = axesTensor.GetShape();
+            dimCount = axesTensorDimensions[0];
+            const Index_t* axesData = axesTensor.GetData<Index_t>();
+            for (size_t i = 0; i < dimCount; ++i)
+            {
+                axes.push_back(gsl::narrow_cast<int32_t>(axesData[i]));
+            }
+        }
+
+        if (operatorInfo.GetInputCount() > 4)
+        {
+            MLOperatorTensor stepsTensor = operatorInfo.GetConstantInputTensor(4);
+            const std::vector<uint32_t>& stepsTensorDimensions = stepsTensor.GetShape();
+            dimCount = stepsTensorDimensions[0];
+            const Index_t* stepsData = stepsTensor.GetData<Index_t>();
+            for (size_t i = 0; i < dimCount; ++i)
+            {
+                steps.push_back(gsl::narrow_cast<int32_t>(stepsData[i]));
+            }
+        }
+    }
+
+    template<typename Info_t>
     void Initialize(
-        const MLOperatorAttributes& operatorAttributes,
-        gsl::span<const DimensionType> inputDimensions
-        );
+        const Info_t& operatorInfo,
+        gsl::span<const DimensionType> inputDimensions,
+        uint32_t opsetVersion
+        )
+    {
+        std::vector<int32_t> starts;
+        std::vector<int32_t> ends;
+        std::vector<int32_t> axes;
+        std::vector<int32_t> steps;
+        if (opsetVersion == 7)
+        {
+            // Get starts, ends and axes from attributes
+            starts = operatorInfo.GetOptionalAttributeVectorInt32(AttrName::Starts);
+            ends = operatorInfo.GetOptionalAttributeVectorInt32(AttrName::Ends);
+            axes = operatorInfo.GetOptionalAttributeVectorInt32(AttrName::Axes);
+        }
+        else if (opsetVersion == 10)
+        {
+            if (operatorInfo.GetConstantInputTensor(1).GetTensorDataType() == MLOperatorTensorDataType::Int32)
+            {
+                ReadIndexTensors<Info_t, int32_t>(operatorInfo, starts, ends, axes, steps);
+            }
+            else
+            {
+                THROW_HR_IF(E_INVALIDARG, operatorInfo.GetConstantInputTensor(1).GetTensorDataType() != MLOperatorTensorDataType::Int64);
+                ReadIndexTensors<Info_t, int64_t>(operatorInfo, starts, ends, axes, steps);
+            }
+        }
+
+        ML_CHECK_VALID_ARGUMENT(starts.size() == ends.size(), "'starts' must equal 'ends' in size.");
+        ML_CHECK_VALID_ARGUMENT(axes.empty() || starts.size() == axes.size(), "'axes' must equal 'starts' in size, or 'axes' must be empty.");
+
+        m_outputDimensions.assign(inputDimensions.begin(), inputDimensions.end());
+        m_offsets.resize(m_outputDimensions.size());
+        m_sizes.resize(m_outputDimensions.size());
+        m_strides.resize(m_outputDimensions.size(), 1); // Only a stride of 1 element is supported by ONNX 1.2.
+
+        // Set initial defaults lest 'starts' and 'ends' arrays are shorter than the dimension count.
+        std::copy(inputDimensions.begin(), inputDimensions.begin() + m_outputDimensions.size(), m_sizes.begin());
+
+        // Clamp selected dimensions to given 'starts' and 'ends'.
+        for (int i = 0, ci = gsl::narrow_cast<int>(starts.size()); i < ci; ++i)
+        {
+            int dimIndex = i;
+            if (!axes.empty())
+            {
+                dimIndex = axes[i];
+            }
+            ML_CHECK_VALID_ARGUMENT(dimIndex < inputDimensions.size(), "'axes' must be valid with within actual input dimensions.");
+
+            // Positive values are offsets from 0.
+            // Negative values are offsets from the dimension's size.
+            int dim = gsl::narrow_cast<int>(inputDimensions[dimIndex]);
+            int start = (starts[i] < 0) ? (starts[i] + dim) : starts[i];
+            int end = (ends[i] < 0) ? (ends[i] + dim) : ends[i];
+
+            // Clamp the dimensions to the slice extents.
+            // Clamp negative numbers to 0, per case test_slice_start_out_of_bounds.
+            start = std::max(start, 0);
+            end = std::min(end, dim);
+            int size = std::max(end - start, 0);
+
+            m_outputDimensions[dimIndex] = size;
+            m_offsets[dimIndex] = start;
+            m_sizes[dimIndex] = gsl::narrow_cast<uint32_t>(size);
+        }
+    }
 
     // Info_t is used to obtain attributes which will be used for calculating the output shape later.
     // Shape_t is used to obtain input shape which will be used for adjusting attribute value.
     template<typename Info_t, typename Shape_t>
-    SliceHelper(const Info_t& info, const Shape_t& shape)
+    SliceHelperBase(const Info_t& info, const Shape_t& shape, uint32_t opsetVersion)
     {
-        Initialize(info, shape.GetInputTensorShape(0));
+        Initialize(info, shape.GetInputTensorShape(0), opsetVersion);
     }
 
     std::vector<EdgeShapes> GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const;
@@ -526,6 +677,21 @@ class SliceHelper
     std::vector<uint32_t> m_strides;
 };
 
+class SliceHelper : public SliceHelperBase
+{
+public:
+    template<typename Info_t, typename Shape_t>
+    SliceHelper(const Info_t& info, const Shape_t& shape) : SliceHelperBase(info, shape, 7) {}
+};
+
+class Slice10Helper : public SliceHelperBase
+{
+public:
+    template<typename Info_t, typename Shape_t>
+    Slice10Helper(const Info_t& info, const Shape_t& shape) : SliceHelperBase(info, shape, 10) {}
+};
+
+
 class PaddingHelper
 {
 public:
@@ -1098,6 +1264,7 @@ class OneHotHelper
 using ShapeInferenceHelper_Conv = ConvHelper;
 using ShapeInferenceHelper_ConvTranspose = ConvTransposeHelper;
+using ShapeInferenceHelper_ConvTransposeWithDynamicPads = ConvTransposeWithDynamicPadsHelper;
 using ShapeInferenceHelper_AveragePool = PoolingHelper;
 using ShapeInferenceHelper_GlobalAveragePool = GlobalPoolingHelper;
 using ShapeInferenceHelper_MaxPool = PoolingHelper;
@@ -1120,7 +1287,8 @@ using ShapeInferenceHelper_Flatten = FlattenHelper;
 using ShapeInferenceHelper_Split = SplitHelper;
 using ShapeInferenceHelper_Transpose = TransposeHelper;
 using ShapeInferenceHelper_Concat = ConcatHelper;
-using ShapeInferenceHelper_Slice = SliceHelper;
+using ShapeInferenceHelper_Slice7 = SliceHelper;
+using ShapeInferenceHelper_Slice10 = Slice10Helper;
 using ShapeInferenceHelper_Pad = PaddingHelper;
 using ShapeInferenceHelper_SpaceToDepth = SpaceToDepthHelper;
 using ShapeInferenceHelper_DepthToSpace = DepthToSpaceHelper;
diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorRegistration.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorRegistration.h
index b9e429a6bf1bd..2b28bf9f6209b 100644
--- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorRegistration.h
+++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorRegistration.h
@@ -178,6 +178,7 @@ namespace OperatorHelper
         static const int sc_sinceVer_Dropout = 10;
         static const int sc_sinceVer_ThresholdedRelu = 10;
         static const int sc_sinceVer_Upsample = 10;
+        static const int sc_sinceVer_Slice = 10;
     } // namespace OnnxOperatorSet10
 
     namespace MsftOperatorSet1
@@ -193,6 +194,7 @@ namespace OperatorHelper
         static const int sc_sinceVer_FusedSum = 1;
         static const int sc_sinceVer_QuantizeLinear = 1;
         static const int sc_sinceVer_DequantizeLinear = 1;
+        static const int sc_sinceVer_ConvTransposeWithDynamicPads = 1;
     } // namespace MsftOperatorSet1
 } // namespace OperatorHelper
diff --git a/onnxruntime/test/onnx/main.cc b/onnxruntime/test/onnx/main.cc
index ed444ccc67ae2..d8bd447e099aa 100644
--- a/onnxruntime/test/onnx/main.cc
+++ b/onnxruntime/test/onnx/main.cc
@@ -563,7 +563,6 @@ int real_main(int argc, char* argv[], Ort::Env& env) {
   broken_tests.insert({"dynamicquantizelinear_expanded", "Temporarily disabled pending investigation"});
   broken_tests.insert({"dynamicquantizelinear_max_adjusted_expanded", "Temporarily disabled pending investigation"});
   broken_tests.insert({"dynamicquantizelinear_min_adjusted_expanded", "Temporarily disabled pending investigation"});
-  broken_tests.insert({"maxpool_with_argmax_2d_precomputed_pads", "Temporarily disabled pending investigation"});
   broken_tests.insert({"mxnet_arcface", "Temporarily disabled pending investigation"});
   broken_tests.insert({"yolov3", "Temporarily disabled pending investigation"});
   broken_tests.insert({"tf_inception_v2", "Temporarily disabled pending investigation"});
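
Reference note (outside the patch): SliceHelperBase::Initialize normalizes negative 'starts'/'ends' and clamps them to each dimension before computing output sizes, regardless of whether the values came from opset-7 attributes or opset-10 constant input tensors. The standalone sketch below mirrors only that normalization and clamping step; the names ComputeSlice and SliceResult are illustrative, plain std::vector stands in for the MLOperatorTensor and attribute plumbing, and the 'steps' input is omitted because the helper reads it but only a stride of 1 is handled.

// Standalone sketch of the slice start/end clamping performed by the helper.
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

struct SliceResult
{
    std::vector<uint32_t> outputDims;
    std::vector<uint32_t> offsets;
    std::vector<uint32_t> sizes;
};

SliceResult ComputeSlice(
    const std::vector<uint32_t>& inputDims,
    std::vector<int32_t> starts,
    std::vector<int32_t> ends,
    std::vector<int32_t> axes) // empty means "apply to leading dimensions in order"
{
    SliceResult result;
    result.outputDims.assign(inputDims.begin(), inputDims.end());
    result.offsets.resize(inputDims.size(), 0);
    result.sizes.assign(inputDims.begin(), inputDims.end());

    for (size_t i = 0; i < starts.size(); ++i)
    {
        size_t dimIndex = axes.empty() ? i : static_cast<size_t>(axes[i]);
        int dim = static_cast<int>(inputDims[dimIndex]);

        // Negative values count back from the end of the dimension.
        int start = (starts[i] < 0) ? starts[i] + dim : starts[i];
        int end = (ends[i] < 0) ? ends[i] + dim : ends[i];

        // Clamp to the valid range; an out-of-bounds start yields an empty slice.
        start = std::max(start, 0);
        end = std::min(end, dim);
        int size = std::max(end - start, 0);

        result.outputDims[dimIndex] = static_cast<uint32_t>(size);
        result.offsets[dimIndex] = static_cast<uint32_t>(start);
        result.sizes[dimIndex] = static_cast<uint32_t>(size);
    }
    return result;
}

int main()
{
    // Slice a [4, 6] tensor with starts = [1, -3], ends = [1000, -1] (opset-7 style attributes).
    SliceResult r = ComputeSlice({4, 6}, {1, -3}, {1000, -1}, {});
    for (uint32_t d : r.outputDims)
    {
        std::cout << d << ' '; // prints "3 2"
    }
    std::cout << '\n';
    return 0;
}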