From c713b09901fe61990ea866ff8e801944a6b8c7e8 Mon Sep 17 00:00:00 2001 From: Gary Miguel Date: Fri, 3 Jun 2022 22:51:26 +0000 Subject: [PATCH 01/20] Register signal ops for op set 17 Note code is mostly being moved, not added. These ops were previously only registered as Microsoft contrib ops and only built if `BUILD_MS_EXPERIMENTAL_OPS=1`. They've been added to the ai.onnx standard op set in version 17. Main components of this change: * Move the kernels from the contrib_ops directory to the core directory. * Add function bodies for ms experimental ops. This will allow old models that use the contrib ops to continue to function. All the function bodies consist of a single op (the new standard op), so performance overhead should be minimal. Minor clean-up also in this change: * De-duplicate get_scalar_value_from_tensor: put it in a new utils.h. * Fix some bugs that caused compilation errors with the experimental ops. Tested with `build.sh --ms_experimental` * Fix some spelling errors and lint violations. * Replace a couple of switch statements with `MLTypeCallDispatcher`. * Use `InlineVector` instead of `std::vector`. 
Unblocks https://github.com/microsoft/onnxruntime/issues/11640 --- .../contrib_ops/cpu/cpu_contrib_kernels.cc | 20 - .../cpu/signal/window_functions.cc | 334 ------- .../core/graph/signal_ops/signal_defs.cc | 510 ++++------ .../providers/cpu/cpu_execution_provider.cc | 920 ++++++++++-------- .../providers}/cpu/signal/dft.cc | 429 ++++---- .../providers}/cpu/signal/dft.h | 18 +- onnxruntime/core/providers/cpu/signal/utils.h | 31 + .../providers/cpu/signal/window_functions.cc | 215 ++++ .../providers}/cpu/signal/window_functions.h | 12 +- .../testdata/kernel_def_hashes/onnx.cpu.json | 40 +- 10 files changed, 1185 insertions(+), 1344 deletions(-) delete mode 100644 onnxruntime/contrib_ops/cpu/signal/window_functions.cc rename onnxruntime/{contrib_ops => core/providers}/cpu/signal/dft.cc (50%) rename onnxruntime/{contrib_ops => core/providers}/cpu/signal/dft.h (72%) create mode 100644 onnxruntime/core/providers/cpu/signal/utils.h create mode 100644 onnxruntime/core/providers/cpu/signal/window_functions.cc rename onnxruntime/{contrib_ops => core/providers}/cpu/signal/window_functions.h (77%) diff --git a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc index 2068b3c3e3f1f..d89d30b62c737 100644 --- a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc @@ -41,16 +41,6 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, FastG class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, NGramRepeatBlock); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, BifurcationDetector); -#ifdef BUILD_MS_EXPERIMENTAL_OPS -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSExperimentalDomain, 1, DFT); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSExperimentalDomain, 1, IDFT); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSExperimentalDomain, 1, HannWindow); 
-class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSExperimentalDomain, 1, HammingWindow); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSExperimentalDomain, 1, BlackmanWindow); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSExperimentalDomain, 1, MelWeightMatrix); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSExperimentalDomain, 1, STFT); -#endif - // ******** Start: Quantization ******************* // class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MatMulInteger16); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, QLinearGlobalAveragePool); @@ -224,16 +214,6 @@ Status RegisterCpuContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - -#ifdef BUILD_MS_EXPERIMENTAL_OPS - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, -#endif // These ops were experimental ops in onnx domain which have been removed now. We add them here as // contrib ops to main backward compatibility BuildKernelCreateInfo, diff --git a/onnxruntime/contrib_ops/cpu/signal/window_functions.cc b/onnxruntime/contrib_ops/cpu/signal/window_functions.cc deleted file mode 100644 index 29256adb264d0..0000000000000 --- a/onnxruntime/contrib_ops/cpu/signal/window_functions.cc +++ /dev/null @@ -1,334 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
- -#ifdef BUILD_MS_EXPERIMENTAL_OPS - -#include "core/providers/common.h" -#include "core/framework/op_kernel.h" -#include "core/util/math_cpuonly.h" -#include "Eigen/src/Core/Map.h" -#include "window_functions.h" -#include - -#include "core/platform/threadpool.h" - -#include -#include - -namespace onnxruntime { -namespace contrib { - -ONNX_OPERATOR_KERNEL_EX( - HannWindow, - kMSExperimentalDomain, - 1, - kCpuExecutionProvider, - KernelDefBuilder().MayInplace(0, 0) - .TypeConstraint("T1", BuildKernelDefConstraints()) - .TypeConstraint("T2", BuildKernelDefConstraints()), - HannWindow); - -ONNX_OPERATOR_KERNEL_EX( - HammingWindow, - kMSExperimentalDomain, - 1, - kCpuExecutionProvider, - KernelDefBuilder().MayInplace(0, 0) - .TypeConstraint("T1", BuildKernelDefConstraints()) - .TypeConstraint("T2", BuildKernelDefConstraints()), - HammingWindow); - -ONNX_OPERATOR_KERNEL_EX( - BlackmanWindow, - kMSExperimentalDomain, - 1, - kCpuExecutionProvider, - KernelDefBuilder().MayInplace(0, 0) - .TypeConstraint("T1", BuildKernelDefConstraints()) - .TypeConstraint("T2", BuildKernelDefConstraints()), - BlackmanWindow); - - -ONNX_OPERATOR_KERNEL_EX( - MelWeightMatrix, - kMSExperimentalDomain, - 1, - kCpuExecutionProvider, - KernelDefBuilder().MayInplace(0, 0) - .TypeConstraint("T1", BuildKernelDefConstraints()) - .TypeConstraint("T2", BuildKernelDefConstraints()) - .TypeConstraint("T3", BuildKernelDefConstraints()), - MelWeightMatrix); - - -template -static Status cosine_sum_window(Tensor* Y, size_t size, float a0, float a1, float a2) { - auto* Y_data = reinterpret_cast(Y->MutableDataRaw()); - - // Calculate the radians to increment per sample - constexpr double pi = 3.14159265; - constexpr double tau = 2 * pi; - const double angular_increment = tau / size; - - for (size_t i = 0; i < size; i++) { - auto a2_component = a2 == 0 ? 
0 : (a2 * cos(2 * angular_increment * i)); - - T& value = *(Y_data + i); - value = static_cast(a0 - (a1 * cos(angular_increment * i)) + a2_component); - } - - return Status::OK(); -} - -template -static T get_scalar_value_from_tensor(const Tensor* tensor) { - ORT_ENFORCE(tensor->Shape().Size() == 1, "Tensor input should have a single value."); - auto data_type = tensor->DataType()->AsPrimitiveDataType()->GetDataType(); - switch (data_type) { - case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: - return static_cast(*reinterpret_cast(tensor->DataRaw())); - case ONNX_NAMESPACE::TensorProto_DataType_DOUBLE: - return static_cast(*reinterpret_cast(tensor->DataRaw())); - case ONNX_NAMESPACE::TensorProto_DataType_INT32: - return static_cast(*reinterpret_cast(tensor->DataRaw())); - case ONNX_NAMESPACE::TensorProto_DataType_INT64: - return static_cast(*reinterpret_cast(tensor->DataRaw())); - default: - ORT_THROW("Unsupported input data type of ", data_type); - } -} - -static Status create_cosine_sum_window( - OpKernelContext* ctx, - onnx::TensorProto_DataType output_datatype, - float a0, float a1, float a2) { - - // Get the size of the window - auto size = get_scalar_value_from_tensor(ctx->Input(0)); - - // Get the output tensor - auto Y_shape = onnxruntime::TensorShape({size}); - auto Y = ctx->Output(0, Y_shape); - - switch (output_datatype) { - case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: { - ORT_RETURN_IF_ERROR((cosine_sum_window(Y, size, a0, a1, a2))); - break; - } - case ONNX_NAMESPACE::TensorProto_DataType_DOUBLE: { - ORT_RETURN_IF_ERROR((cosine_sum_window(Y, size, a0, a1, a2))); - break; - } - case ONNX_NAMESPACE::TensorProto_DataType_INT8: { - ORT_RETURN_IF_ERROR((cosine_sum_window(Y, size, a0, a1, a2))); - break; - } - case ONNX_NAMESPACE::TensorProto_DataType_INT16: { - ORT_RETURN_IF_ERROR((cosine_sum_window(Y, size, a0, a1, a2))); - break; - } - case ONNX_NAMESPACE::TensorProto_DataType_INT32: { - ORT_RETURN_IF_ERROR((cosine_sum_window(Y, size, a0, a1, a2))); - 
break; - } - case ONNX_NAMESPACE::TensorProto_DataType_INT64: { - ORT_RETURN_IF_ERROR((cosine_sum_window(Y, size, a0, a1, a2))); - break; - } - case ONNX_NAMESPACE::TensorProto_DataType_UINT8: { - ORT_RETURN_IF_ERROR((cosine_sum_window(Y, size, a0, a1, a2))); - break; - } - case ONNX_NAMESPACE::TensorProto_DataType_UINT16: { - ORT_RETURN_IF_ERROR((cosine_sum_window(Y, size, a0, a1, a2))); - break; - } - case ONNX_NAMESPACE::TensorProto_DataType_UINT32: { - ORT_RETURN_IF_ERROR((cosine_sum_window(Y, size, a0, a1, a2))); - break; - } - case ONNX_NAMESPACE::TensorProto_DataType_UINT64: { - ORT_RETURN_IF_ERROR((cosine_sum_window(Y, size, a0, a1, a2))); - break; - } - default: - ORT_THROW("Unsupported input data type of ", output_datatype); - } - - return Status::OK(); -} - -Status HannWindow::Compute(OpKernelContext* ctx) const { - // HannWindows are a special case of Cosine-Sum Windows which take the following form: - // w[n] = SUM_k=0_K( (-1)^k * a_k * cos(2*pi*k*n/N) ) with values the following values for a_k: - float a0 = .5f; - float a1 = a0; - float a2 = 0; - return create_cosine_sum_window(ctx, data_type_, a0, a1, a2); -} - -Status HammingWindow::Compute(OpKernelContext* ctx) const { - // HammingWindows are a special case of Cosine-Sum Windows which take the following form: - // w[n] = SUM_k=0_K( (-1)^k * a_k * cos(2*pi*k*n/N) ) with values the following values for a_k: - float a0 = 25.f / 46.f; - float a1 = 1 - a0; - float a2 = 0; - return create_cosine_sum_window(ctx, data_type_, a0, a1, a2); -} - -Status BlackmanWindow::Compute(OpKernelContext* ctx) const { - // BlackmanWindows are a special case of Cosine-Sum Windows which take the following form: - // w[n] = SUM_k=0_K( (-1)^k * a_k * cos(2*pi*k*n/N) ) with values the following values for a_k: - float alpha = .16f; - float a2 = alpha / 2.f; - float a0 = .5f - a2; - float a1 = .5f; - return create_cosine_sum_window(ctx, data_type_, a0, a1, a2); -} - -static inline double hz_to_mel_scale(double hz) { - return 
2595 * std::log10(1 + hz / 700); -} - -static inline double mel_scale_to_hz(double mels) { - return 700 * (pow(10, (mels / 2595)) - 1); -} - -template -Status create_mel_weight_matrix(OpKernelContext* ctx, int64_t num_mel_bins, int64_t dft_length, int64_t sample_rate, float lower_edge_hertz, float upper_edge_hertz) { - // Determine the width of the spectrogram. - // This is determined as half the size of the fft size. The first element of the spectrum is always retained, - // and the remaining are halved. The second half can be discarded due to the conjugate symmetry of the output with real valued ffts. - // Taken together the formula for the size of the output will be std::floor(dft_length / 2) + 1. - int64_t num_spectrogram_bins = static_cast(std::floor(dft_length / 2 + 1)); - - // Checks - auto lowest_index = std::floor(((dft_length + 1) * lower_edge_hertz) / sample_rate); - auto highest_index = std::floor(((dft_length + 1) * upper_edge_hertz) / sample_rate); - ORT_ENFORCE(lowest_index >= 0 && lowest_index < num_spectrogram_bins, "lower_edge_hertz produces a mel triangle filter bank that is out of range given the dft_length and the sample_rate."); - ORT_ENFORCE(highest_index >= 0 && highest_index < num_spectrogram_bins, "upper_edge_hertz produces a mel triangle filter bank that is out of range given the dft_length and the sample_rate."); - - // Create the output shape - onnxruntime::TensorShape output_shape( - { - static_cast(num_spectrogram_bins), - num_mel_bins - }); - auto* Y = ctx->Output(0, output_shape); - - // Get the raw output data - auto* Y_data = reinterpret_cast(Y->MutableDataRaw()); - - // Set the weight matrix to 0 - memset(Y_data, 0, num_spectrogram_bins * num_mel_bins * sizeof(T)); - - // The mel filterbank is a triangular shaped peak with a height of 1 and a base equal to the size of the MEL range divided by - // the number of bins needed times 2. 
This triagle is then slid across the mel domain linearly, with a constant step size that - // is equal to half of the base of the triange. To accomodate N bins, N+2 data points will be needed to determine the - // start, center and end points of each mel triange filter. - // - // low_frequency where the mel triangle filter banks begin, and they end on the high_frequency_mel - // The range is divided evenly to create the needed points corresponding to the begin, center, end points of each triangle filterbank - std::vector frequency_bins(num_mel_bins + 2); - auto low_frequency_mel = hz_to_mel_scale(lower_edge_hertz); - auto high_frequency_mel = hz_to_mel_scale(upper_edge_hertz); - auto mel_step = (high_frequency_mel - low_frequency_mel) / static_cast(frequency_bins.size()); - - // Convert each point from mel scale back to hertz, and then compute the corresponding index in the fft - for (size_t i = 0; i < frequency_bins.size(); i++) { - auto hz = mel_scale_to_hz(low_frequency_mel + mel_step * i); - frequency_bins[i] = static_cast(std::floor(((dft_length + 1) * hz) / sample_rate)); - } - - for (size_t i = 0; i < static_cast(num_mel_bins); i++) { - auto lower_frequency_value = frequency_bins[i]; //left - auto center_frequency_point = frequency_bins[i+1]; //center - auto higher_frequency_point = frequency_bins[i+2]; //right - - auto low_to_center = center_frequency_point - lower_frequency_value; - if (low_to_center == 0) { - auto& current_element = *(Y_data + (center_frequency_point * num_mel_bins) + i); - current_element = static_cast(1); - } else { - for (size_t j = lower_frequency_value; j <= center_frequency_point; j++) { - auto& current_element = *(Y_data + (j * num_mel_bins) + i); - current_element = static_cast((j - lower_frequency_value) / static_cast(low_to_center)); - } - } - - auto center_to_high = higher_frequency_point - center_frequency_point; - if (center_to_high > 0) { - for (size_t j = center_frequency_point; j < higher_frequency_point; j++) { - auto& 
current_element = *(Y_data + (j * num_mel_bins) + i); - current_element = static_cast((higher_frequency_point - j) / static_cast(center_to_high)); - } - } - } - - return Status::OK(); -} - -static Status create_mel_weight_matrix(OpKernelContext* ctx, onnx::TensorProto_DataType output_datatype, - int64_t num_mel_bins, int64_t dft_length, int64_t sample_rate, float lower_edge_hertz, float upper_edge_hertz) { - switch (output_datatype) { - case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: { - ORT_RETURN_IF_ERROR((create_mel_weight_matrix(ctx, num_mel_bins, dft_length, sample_rate, lower_edge_hertz, upper_edge_hertz))); - break; - } - case ONNX_NAMESPACE::TensorProto_DataType_DOUBLE: { - ORT_RETURN_IF_ERROR((create_mel_weight_matrix(ctx, num_mel_bins, dft_length, sample_rate, lower_edge_hertz, upper_edge_hertz))); - break; - } - case ONNX_NAMESPACE::TensorProto_DataType_INT8: { - ORT_RETURN_IF_ERROR((create_mel_weight_matrix(ctx, num_mel_bins, dft_length, sample_rate, lower_edge_hertz, upper_edge_hertz))); - break; - } - case ONNX_NAMESPACE::TensorProto_DataType_INT16: { - ORT_RETURN_IF_ERROR((create_mel_weight_matrix(ctx, num_mel_bins, dft_length, sample_rate, lower_edge_hertz, upper_edge_hertz))); - break; - } - case ONNX_NAMESPACE::TensorProto_DataType_INT32: { - ORT_RETURN_IF_ERROR((create_mel_weight_matrix(ctx, num_mel_bins, dft_length, sample_rate, lower_edge_hertz, upper_edge_hertz))); - break; - } - case ONNX_NAMESPACE::TensorProto_DataType_INT64: { - ORT_RETURN_IF_ERROR((create_mel_weight_matrix(ctx, num_mel_bins, dft_length, sample_rate, lower_edge_hertz, upper_edge_hertz))); - break; - } - case ONNX_NAMESPACE::TensorProto_DataType_UINT8: { - ORT_RETURN_IF_ERROR((create_mel_weight_matrix(ctx, num_mel_bins, dft_length, sample_rate, lower_edge_hertz, upper_edge_hertz))); - break; - } - case ONNX_NAMESPACE::TensorProto_DataType_UINT16: { - ORT_RETURN_IF_ERROR((create_mel_weight_matrix(ctx, num_mel_bins, dft_length, sample_rate, lower_edge_hertz, 
upper_edge_hertz))); - break; - } - case ONNX_NAMESPACE::TensorProto_DataType_UINT32: { - ORT_RETURN_IF_ERROR((create_mel_weight_matrix(ctx, num_mel_bins, dft_length, sample_rate, lower_edge_hertz, upper_edge_hertz))); - break; - } - case ONNX_NAMESPACE::TensorProto_DataType_UINT64: { - ORT_RETURN_IF_ERROR((create_mel_weight_matrix(ctx, num_mel_bins, dft_length, sample_rate, lower_edge_hertz, upper_edge_hertz))); - break; - } - default: - ORT_THROW("Unsupported input data type of ", output_datatype); - } - return Status::OK(); -} - -Status MelWeightMatrix::Compute(OpKernelContext* ctx) const { - const auto num_mel_bins = get_scalar_value_from_tensor(ctx->Input(0)); - const auto dft_length = get_scalar_value_from_tensor(ctx->Input(1)); - const auto sample_rate = get_scalar_value_from_tensor(ctx->Input(2)); - const auto lower_edge_hertz = get_scalar_value_from_tensor(ctx->Input(3)); - const auto upper_edge_hertz = get_scalar_value_from_tensor(ctx->Input(4)); - - ORT_RETURN_IF_ERROR(create_mel_weight_matrix(ctx, data_type_, num_mel_bins, dft_length, sample_rate, lower_edge_hertz, upper_edge_hertz)); - return Status::OK(); -} - -} // namespace contrib -} // namespace onnxruntime - -#endif diff --git a/onnxruntime/core/graph/signal_ops/signal_defs.cc b/onnxruntime/core/graph/signal_ops/signal_defs.cc index 27e077c9fefe4..ffca16754319e 100644 --- a/onnxruntime/core/graph/signal_ops/signal_defs.cc +++ b/onnxruntime/core/graph/signal_ops/signal_defs.cc @@ -3,17 +3,21 @@ #ifdef BUILD_MS_EXPERIMENTAL_OPS +#include "core/graph/signal_ops/signal_defs.h" + +#include +#include + #include "core/framework/tensorprotoutils.h" #include "core/providers/common.h" #include "core/graph/constants.h" -#include "core/graph/signal_ops/signal_defs.h" #include "core/graph/op.h" #include "onnx/defs/schema.h" #include "onnx/defs/shape_inference.h" #include "onnx/defs/tensor_proto_util.h" -#include - +// NOTE: These were added to the standard op set. 
We register them under the MS domain +// for backwards compatibility, but new users should use the standard ops instead. Ideally these would be deleted. namespace onnxruntime { namespace signal { @@ -50,7 +54,8 @@ inline const ONNX_NAMESPACE::TensorShapeProto* getOptionalInputShape(ONNX_NAMESP } const auto value_case = input_type->value_case(); - if (value_case != ONNX_NAMESPACE::TypeProto::kTensorType && value_case != ONNX_NAMESPACE::TypeProto::kSparseTensorType) { + if (value_case != ONNX_NAMESPACE::TypeProto::kTensorType && + value_case != ONNX_NAMESPACE::TypeProto::kSparseTensorType) { fail_type_inference("Attribute expected to have tensor or sparse tensor type"); } if (value_case == ONNX_NAMESPACE::TypeProto::kTensorType) { @@ -63,42 +68,28 @@ inline const ONNX_NAMESPACE::TensorShapeProto* getOptionalInputShape(ONNX_NAMESP std::function CosineSumWindowOpDocGenerator(const char* name) { return [name](OpSchema& schema) { std::string doc; - POPULATE_OP_DOC_STR( - doc = R"DOC( + POPULATE_OP_DOC_STR(doc = R"DOC( Generates a {name} window as described in the paper https://ieeexplore.ieee.org/document/1455106. )DOC"; - ReplaceAll(doc, "{name}", name);); + ReplaceAll(doc, "{name}", name);); schema.SetDoc(doc); schema.Attr("output_datatype", "The data type of the output tensor. " "Strictly must be one of the values from DataType enum in TensorProto whose values correspond to T2. " "The default value is 1 = FLOAT. ", - AttributeProto::INT, - static_cast(onnx::TensorProto_DataType::TensorProto_DataType_FLOAT)); + AttributeProto::INT, static_cast(onnx::TensorProto_DataType::TensorProto_DataType_FLOAT)); schema.Attr("periodic", "If 1, returns a window to be used as periodic function. If 0, return a symmetric window. " - "When 'periodic' is specified, hann computes a window of length size + 1 and returns the first size points. " - "The default value is 1. 
", - AttributeProto::INT, - static_cast(1)); - schema.Input(0, - "size", - "A scalar value indicating the length of the window.", - "T1", - OpSchema::Single, - true, - 1, + "When 'periodic' is specified, hann computes a window of length size + 1 and returns the first size " + "points. The default value is 1. ", + AttributeProto::INT, static_cast(1)); + schema.Input(0, "size", "A scalar value indicating the length of the window.", "T1", OpSchema::Single, true, 1, OpSchema::NonDifferentiable); - schema.Output(0, - "output", + schema.Output(0, "output", "A Hann window with length: size. " "The output has the shape: [size].", - "T2", - OpSchema::Single, - true, - 1, - OpSchema::NonDifferentiable); + "T2", OpSchema::Single, true, 1, OpSchema::NonDifferentiable); schema.TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { // Update the output data type to the output_datatype auto output_datatype = getAttribute(ctx, "output_datatype", @@ -132,66 +123,93 @@ Generates a {name} window as described in the paper https://ieeexplore.ieee.org/ }; } +ONNX_NAMESPACE::NodeProto WindowOpFunctionNode(const char* name) { + ONNX_NAMESPACE::NodeProto node; + node.set_op_type(std::string(name) + "Window"); + + auto* output_datatype = node.add_attribute(); + output_datatype->set_name("output_datatype"); + output_datatype->set_ref_attr_name("output_datatype"); + + auto* periodic = node.add_attribute(); + periodic->set_name("periodic"); + periodic->set_ref_attr_name("periodic"); + + node.add_input("size"); + node.add_output("output"); + return node; +} + void RegisterSignalSchemas() { + ONNX_NAMESPACE::NodeProto dft_function_node; + dft_function_node.set_op_type("DFT"); + + auto* dft_function_onesided = dft_function_node.add_attribute(); + dft_function_onesided->set_name("onesided"); + dft_function_onesided->set_ref_attr_name("onesided"); + + auto* dft_function_axis = dft_function_node.add_attribute(); + dft_function_axis->set_name("axis"); + 
dft_function_axis->set_ref_attr_name("axis"); + + auto* dft_function_inverse = dft_function_node.add_attribute(); + dft_function_inverse->set_name("inverse"); + dft_function_inverse->set_ref_attr_name("inverse"); + + dft_function_node.add_input("input"); + dft_function_node.add_input("dft_length"); + dft_function_node.add_output("output"); + + ONNX_NAMESPACE::OperatorSetIdProto onnx_op_set_17; + onnx_op_set_17.set_domain(kOnnxDomain); + onnx_op_set_17.set_version(17); + MS_SIGNAL_OPERATOR_SCHEMA(DFT) .SetDomain(kMSExperimentalDomain) .SinceVersion(1) .SetDoc(R"DOC(DFT)DOC") .Attr("onesided", - "If True (default), only values for half of the fft size are returned because the real-to-complex Fourier transform satisfies the conjugate symmetry." - "The output tensor will return the first floor(n_fft/2) + 1 values from the DFT." - "Values can be 0 or 1.", - AttributeProto::AttributeType::AttributeProto_AttributeType_INT, - static_cast(0)) + "If True (default), only values for half of the fft size are returned because the real-to-complex Fourier " + "transform satisfies the conjugate symmetry.The output tensor will return the first floor(n_fft/2) + 1 " + "values from the DFT. Values can be 0 or 1.", + AttributeProto::AttributeType::AttributeProto_AttributeType_INT, static_cast(0)) .Attr("axis", - "The axis on which to perform the DFT. By default this value is set to 0, which corresponds to the first dimension after the batch index." - "This value must be less than signal_dimN, where signal_dimN is the number of dimensions in the signal.", - AttributeProto::AttributeType::AttributeProto_AttributeType_INT, - static_cast(0)) + "The axis on which to perform the DFT. By default this value is set to 0, which corresponds to the first " + "dimension after the batch index. 
This value must be less than signal_dimN, where signal_dimN is the " + "number of dimensions in the signal.", + AttributeProto::AttributeType::AttributeProto_AttributeType_INT, static_cast(0)) .Attr("inverse", - "Whether to perform the inverse discrete fourier transform. By default this value is set to 0, which corresponds to false.", - AttributeProto::INT, - static_cast(0)) - .Input(0, - "input", - "For real input, the following shape is expected: [batch_idx][signal_dim1][signal_dim2]...[signal_dimN][1]. " - "For complex input, the following shape is expected: [batch_idx][signal_dim1][signal_dim2]...[signal_dimN][2]. " - "The first dimension is the batch dimension. " - "The following N dimentions correspond to the signal's dimensions. " + "Whether to perform the inverse discrete fourier transform. By default this value is set to 0, which " + "corresponds to false.", + AttributeProto::INT, static_cast(0)) + .Input(0, "input", + "For real input, the following shape is expected: [batch_idx][signal_dim1][signal_dim2]..." + "[signal_dimN][1]. For complex input, the following shape is expected: " + "[batch_idx][signal_dim1][signal_dim2]...[signal_dimN][2]. The first dimension is the batch dimension. " + "The following N dimensions correspond to the signal's dimensions. " "The final dimension represents the real and imaginary parts of the value in that order.", - "T1", - OpSchema::Single, - true, - 1, - OpSchema::NonDifferentiable) - .Input(1, - "dft_length", + "T1", OpSchema::Single, true, 1, OpSchema::NonDifferentiable) + .Input(1, "dft_length", "The length of the signal." "If greater than the axis dimension, the signal will be zero-padded up to dft_length. " "If less than the axis dimension, only the first dft_length values will be used as the signal. " "It's an optional value. 
", - "T2", - OpSchema::Optional, - true, - 1, - OpSchema::NonDifferentiable) - .Output(0, - "output", + "T2", OpSchema::Optional, true, 1, OpSchema::NonDifferentiable) + .Output(0, "output", "The Fourier Transform of the input vector." - "If onesided is 0, the following shape is expected: [batch_idx][signal_dim1][signal_dim2]...[signal_dimN][2]. " - "If axis=0 and onesided is 1, the following shape is expected: [batch_idx][floor(signal_dim1/2)+1][signal_dim2]...[signal_dimN][2]. " - "If axis=1 and onesided is 1, the following shape is expected: [batch_idx][signal_dim1][floor(signal_dim2/2)+1]...[signal_dimN][2]. " - "If axis=N-1 and onesided is 1, the following shape is expected: [batch_idx][signal_dim1][signal_dim2]...[floor(signal_dimN/2)+1][2]. " + "If onesided is 0, the following shape is expected: " + "[batch_idx][signal_dim1][signal_dim2]...[signal_dimN][2]. " + "If axis=0 and onesided is 1, the following shape is expected: " + "[batch_idx][floor(signal_dim1/2)+1][signal_dim2]...[signal_dimN][2]. " + "If axis=1 and onesided is 1, the following shape is expected: " + "[batch_idx][signal_dim1][floor(signal_dim2/2)+1]...[signal_dimN][2]. " + "If axis=N-1 and onesided is 1, the following shape is expected: " + "[batch_idx][signal_dim1][signal_dim2]...[floor(signal_dimN/2)+1][2]. 
" "The signal_dim at the specified axis is equal to the dft_length.", "T1") - .TypeConstraint( - "T1", - {"tensor(float16)", "tensor(float)", "tensor(double)", "tensor(bfloat16)"}, - "Constrain input and output types to float tensors.") - .TypeConstraint( - "T2", - {"tensor(int32)", "tensor(int64)"}, - "Constrain scalar length types to int64_t.") + .TypeConstraint("T1", {"tensor(float16)", "tensor(float)", "tensor(double)", "tensor(bfloat16)"}, + "Constrain input and output types to float tensors.") + .TypeConstraint("T2", {"tensor(int32)", "tensor(int64)"}, "Constrain scalar length types to int64_t.") .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { bool is_onesided = static_cast(getAttribute(ctx, "onesided", 0)); bool inverse = static_cast(getAttribute(ctx, "inverse", 0)); @@ -216,11 +234,7 @@ void RegisterSignalSchemas() { auto rank = input_shape.dim_size(); if (!(-rank <= axis && axis < rank)) { - fail_shape_inference( - "axis attribute value ", - axis, - " is invalid for a tensor of rank ", - rank); + fail_shape_inference("axis attribute value ", axis, " is invalid for a tensor of rank ", rank); } auto axis_idx = (axis >= 0 ? 
axis : axis + rank); @@ -275,58 +289,58 @@ void RegisterSignalSchemas() { } updateOutputShape(ctx, 0, result_shape_proto); - }); + }) + .FunctionBody({dft_function_node}, {onnx_op_set_17}); + + ONNX_NAMESPACE::NodeProto idft_function_node; + idft_function_node.set_op_type("DFT"); + + auto* idft_function_inverse = idft_function_node.add_attribute(); + idft_function_inverse->set_name("inverse"); + idft_function_inverse->set_i(1); + + auto* idft_function_axis = idft_function_node.add_attribute(); + idft_function_axis->set_name("axis"); + idft_function_axis->set_ref_attr_name("axis"); + + idft_function_node.add_input("input"); + idft_function_node.add_input("dft_length"); + idft_function_node.add_output("output"); MS_SIGNAL_OPERATOR_SCHEMA(IDFT) .SetDomain(kMSExperimentalDomain) .SinceVersion(1) .SetDoc(R"DOC(IDFT)DOC") .Attr("axis", - "The axis on which to perform the DFT. By default this value is set to 0, which corresponds to the first dimension after the batch index." + "The axis on which to perform the DFT. By default this value is set to 0, which corresponds to the first " + "dimension after the batch index." "This value must be less than signal_dimN, where signal_dimN is the number of dimensions in the signal.", - AttributeProto::AttributeType::AttributeProto_AttributeType_INT, - static_cast(0)) - .Input(0, - "input", - "For real multi-dimensional input, the following shape is expected: [batch_idx][signal_dim1][signal_dim2]...[signal_dimN][1]." - "For complex multi-dimensional input, the following shape is expected: [batch_idx][signal_dim1][signal_dim2]...[signal_dimN][2]." + AttributeProto::AttributeType::AttributeProto_AttributeType_INT, static_cast(0)) + .Input(0, "input", + "For real multi-dimensional input, the following shape is expected: " + "[batch_idx][signal_dim1][signal_dim2]...[signal_dimN][1]." + "For complex multi-dimensional input, the following shape is expected: " + "[batch_idx][signal_dim1][signal_dim2]...[signal_dimN][2]." 
"The first dimension is the batch dimension." "The final dimension represents the real and imaginary parts of the value.", "T1") - .Input(1, - "dft_length", + .Input(1, "dft_length", "The length of the signal." "If greater than the axis dimension, the signal will be zero-padded up to dft_length. " "If less than the axis dimension, only the first dft_length values will be used as the signal. " "It's an optional value. ", - "T2", - OpSchema::Optional, - true, - 1, - OpSchema::NonDifferentiable) - .Output(0, - "output", + "T2", OpSchema::Optional, true, 1, OpSchema::NonDifferentiable) + .Output(0, "output", "The inverse discrete Fourier transform of the input. " "The signal_dim at the specified axis is equal to the dft_length." "The expected shape is [batch_idx][signal_dim1][signal_dim2]...[signal_dimN][2]" "For all types of input, the last dimension of the output represents the components of a complex number.", - "T1", - OpSchema::Single, - true, - 1, - OpSchema::NonDifferentiable) - .TypeConstraint( - "T1", - {"tensor(float16)", "tensor(float)", "tensor(double)", "tensor(bfloat16)"}, - "Constrain input and output types to float tensors.") - .TypeConstraint( - "T2", - {"tensor(int64)"}, - "Constrain scalar length types to int64_t.") + "T1", OpSchema::Single, true, 1, OpSchema::NonDifferentiable) + .TypeConstraint("T1", {"tensor(float16)", "tensor(float)", "tensor(double)", "tensor(bfloat16)"}, + "Constrain input and output types to float tensors.") + .TypeConstraint("T2", {"tensor(int64)"}, "Constrain scalar length types to int64_t.") .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { propagateElemTypeFromInputToOutput(ctx, 0, 0); - const int64_t batch_ndim = 1; - auto& input_shape = getInputShape(ctx, 0); ONNX_NAMESPACE::TensorShapeProto result_shape = input_shape; auto dim_size = static_cast(input_shape.dim_size()); @@ -339,84 +353,62 @@ void RegisterSignalSchemas() { } updateOutputShape(ctx, 0, result_shape); - }); + }) + 
.FunctionBody({idft_function_node}, {onnx_op_set_17}); + + ONNX_NAMESPACE::NodeProto stft_function_node; + stft_function_node.set_op_type("STFT"); + + auto* stft_function_onesided = stft_function_node.add_attribute(); + stft_function_onesided->set_name("onesided"); + stft_function_onesided->set_ref_attr_name("onesided"); + + stft_function_node.add_input("signal"); + stft_function_node.add_input("frame_step"); + stft_function_node.add_input("window"); + stft_function_node.add_input("frame_length"); + stft_function_node.add_output("output"); MS_SIGNAL_OPERATOR_SCHEMA(STFT) .SetDomain(kMSExperimentalDomain) .SinceVersion(1) .SetDoc(R"DOC(STFT)DOC") - .Attr( - "onesided", - "If onesided is 1, only values for w in [0, 1, 2, ..., floor(n_fft/2) + 1] are returned because " - "the real-to-complex Fourier transform satisfies the conjugate symmetry, i.e., X[m, w] = X[m,w] = " - "X[m,n_fft-w]*. Note if the input or window tensors are complex, then onesided output is not possible. " - "Enabling onesided with real inputs performs a Real-valued fast Fourier transform (RFFT)." - "When invoked with real or complex valued input, the default value is 1. " - "Values can be 0 or 1.", - AttributeProto::INT, - static_cast(1)) - .Input(0, - "signal", + .Attr("onesided", + "If onesided is 1, only values for w in [0, 1, 2, ..., floor(n_fft/2) + 1] are returned because " + "the real-to-complex Fourier transform satisfies the conjugate symmetry, i.e., X[m, w] = X[m,w] = " + "X[m,n_fft-w]*. Note if the input or window tensors are complex, then onesided output is not possible. " + "Enabling onesided with real inputs performs a Real-valued fast Fourier transform (RFFT)." + "When invoked with real or complex valued input, the default value is 1. " + "Values can be 0 or 1.", + AttributeProto::INT, static_cast(1)) + .Input(0, "signal", "Input tensor representing a real or complex valued signal. " "For real input, the following shape is expected: [batch_size][signal_length][1.
" "For complex input, the following shape is expected: [batch_size][signal_length][2], where " "[batch_size][signal_length][0] represents the real component and [batch_size][signal_length][1] " "represents the imaginary component of the signal.", - "T1", - OpSchema::Single, - true, - 1, + "T1", OpSchema::Single, true, 1, OpSchema::NonDifferentiable) + .Input(1, "frame_step", "The number of samples to step between successive DFTs.", "T2", OpSchema::Single, true, 1, OpSchema::NonDifferentiable) - .Input(1, - "frame_step", - "The number of samples to step between successive DFTs.", - "T2", - OpSchema::Single, - true, - 1, - OpSchema::NonDifferentiable) - .Input(2, - "window", + .Input(2, "window", "A tensor representing the window that will be slid over the signal." "The window must have rank 1 with shape: [window_shape]. " "It's an optional value. ", - "T1", - OpSchema::Optional, - true, - 1, - OpSchema::NonDifferentiable) - .Input(3, - "frame_length", + "T1", OpSchema::Optional, true, 1, OpSchema::NonDifferentiable) + .Input(3, "frame_length", "A scalar representing the size of the DFT. " "It's an optional value.", - "T2", - OpSchema::Optional, - true, - 1, - OpSchema::NonDifferentiable) - .Output(0, - "output", + "T2", OpSchema::Optional, true, 1, OpSchema::NonDifferentiable) + .Output(0, "output", "The Short-time Fourier Transform of the signals." 
"If onesided is 1, the output has the shape: [batch_size][frames][dft_unique_bins][2], where " "dft_unique_bins is frame_length // 2 + 1 (the unique components of the DFT) " "If onesided is 0, the output has the shape: [batch_size][frames][frame_length][2], where frame_length " "is the length of the DFT.", - "T1", - OpSchema::Single, - true, - 1, - OpSchema::NonDifferentiable) - .TypeConstraint( - "T1", - {"tensor(float)", - "tensor(float16)", - "tensor(double)", - "tensor(bfloat16)"}, - "Constrain signal and output to float tensors.") - .TypeConstraint( - "T2", - {"tensor(int32)", "tensor(int64)"}, - "Constrain scalar length types to int64_t.") + "T1", OpSchema::Single, true, 1, OpSchema::NonDifferentiable) + .TypeConstraint("T1", {"tensor(float)", "tensor(float16)", "tensor(double)", "tensor(bfloat16)"}, + "Constrain signal and output to float tensors.") + .TypeConstraint("T2", {"tensor(int32)", "tensor(int64)"}, "Constrain scalar length types to int64_t.") .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { propagateElemTypeFromInputToOutput(ctx, 0, 0); @@ -461,7 +453,7 @@ void RegisterSignalSchemas() { const ONNX_NAMESPACE::TensorShapeProto* window_shape = nullptr; if (ctx.getNumInputs() >= 3) { - window_shape = getOptionalInputShape(ctx, 2); + window_shape = ONNX_NAMESPACE::getOptionalInputShape(ctx, 2); } else { window_shape = nullptr; } @@ -523,122 +515,55 @@ void RegisterSignalSchemas() { result_shape_proto.add_dim()->set_dim_value(dft_size); result_shape_proto.add_dim()->set_dim_value(2); updateOutputShape(ctx, 0, result_shape_proto); - }); + }) + .FunctionBody({stft_function_node}, {onnx_op_set_17}); // Window Functions MS_SIGNAL_OPERATOR_SCHEMA(HannWindow) .SetDomain(kMSExperimentalDomain) .SinceVersion(1) .FillUsing(CosineSumWindowOpDocGenerator("Hann")) - .TypeConstraint( - "T1", - {"tensor(int32)", "tensor(int64)"}, - "Constrain the input size to int64_t.") - .TypeConstraint( - "T2", - 
ONNX_NAMESPACE::OpSchema::all_numeric_types_with_bfloat(), - "Constrain output types to numeric tensors.") - .FunctionBody(R"ONNX( - { - A0 = Constant () - A1 = Constant () - A2 = Constant () - Zero = Constant () - One = Constant () - Two = Constant () - Tau = Constant () - Size_FP = Cast (size) - AngularIncrement = Div (Tau, Size_FP) - Range = Range (Zero, Size_FP, One) - RangeAngular = Mul (Range, AngularIncrement) - TwoRangeAngular = Mul (RangeAngular, Two) - CosTwoRangeAngular = Cos (TwoRangeAngular) - A2_Component = Mul (A2, CosTwoRangeAngular) - CosRangeAngular = Cos (RangeAngular) - A1_Component = Mul (A1, CosRangeAngular) - Temp0 = Add (A1_Component, A2_Component) - Temp1 = Sub (A0, Temp0) - output = Cast (Temp1) - } - )ONNX"); + .TypeConstraint("T1", {"tensor(int32)", "tensor(int64)"}, "Constrain the input size to int64_t.") + .TypeConstraint("T2", ONNX_NAMESPACE::OpSchema::all_numeric_types_with_bfloat(), + "Constrain output types to numeric tensors.") + .FunctionBody({WindowOpFunctionNode("Hann")}, {onnx_op_set_17}); MS_SIGNAL_OPERATOR_SCHEMA(HammingWindow) .SetDomain(kMSExperimentalDomain) .SinceVersion(1) .FillUsing(CosineSumWindowOpDocGenerator("Hamming")) - .TypeConstraint( - "T1", - {"tensor(int32)", "tensor(int64)"}, - "Constrain the input size to int64_t.") - .TypeConstraint( - "T2", - ONNX_NAMESPACE::OpSchema::all_numeric_types_with_bfloat(), - "Constrain output types to numeric tensors.") - .FunctionBody(R"ONNX( - { - A0 = Constant () - A1 = Constant () - A2 = Constant () - Zero = Constant () - One = Constant () - Two = Constant () - Tau = Constant () - Size_FP = Cast (size) - AngularIncrement = Div (Tau, Size_FP) - Range = Range (Zero, Size_FP, One) - RangeAngular = Mul (Range, AngularIncrement) - TwoRangeAngular = Mul (RangeAngular, Two) - CosTwoRangeAngular = Cos (TwoRangeAngular) - A2_Component = Mul (A2, CosTwoRangeAngular) - CosRangeAngular = Cos (RangeAngular) - A1_Component = Mul (A1, CosRangeAngular) - Temp0 = Add (A1_Component, 
A2_Component) - Temp1 = Sub (A0, Temp0) - output = Cast (Temp1) - } - )ONNX"); + .TypeConstraint("T1", {"tensor(int32)", "tensor(int64)"}, "Constrain the input size to int64_t.") + .TypeConstraint("T2", ONNX_NAMESPACE::OpSchema::all_numeric_types_with_bfloat(), + "Constrain output types to numeric tensors.") + .FunctionBody({WindowOpFunctionNode("Hamming")}, {onnx_op_set_17}); MS_SIGNAL_OPERATOR_SCHEMA(BlackmanWindow) .SetDomain(kMSExperimentalDomain) .SinceVersion(1) .FillUsing(CosineSumWindowOpDocGenerator("Blackman")) - .TypeConstraint( - "T1", - {"tensor(int32)", "tensor(int64)"}, - "Constrain the input size to int64_t.") - .TypeConstraint( - "T2", - ONNX_NAMESPACE::OpSchema::all_numeric_types_with_bfloat(), - "Constrain output types to numeric tensors.") - .FunctionBody(R"ONNX( - { - A0 = Constant () - A1 = Constant () - A2 = Constant () - Zero = Constant () - One = Constant () - Two = Constant () - Tau = Constant () - Size_FP = Cast (size) - AngularIncrement = Div (Tau, Size_FP) - Range = Range (Zero, Size_FP, One) - RangeAngular = Mul (Range, AngularIncrement) - TwoRangeAngular = Mul (RangeAngular, Two) - CosTwoRangeAngular = Cos (TwoRangeAngular) - A2_Component = Mul (A2, CosTwoRangeAngular) - CosRangeAngular = Cos (RangeAngular) - A1_Component = Mul (A1, CosRangeAngular) - Temp0 = Add (A1_Component, A2_Component) - Temp1 = Sub (A0, Temp0) - output = Cast (Temp1) - } - )ONNX"); - - static const char* MelWeightMatrix_ver17_doc = R"DOC( + .TypeConstraint("T1", {"tensor(int32)", "tensor(int64)"}, "Constrain the input size to int64_t.") + .TypeConstraint("T2", ONNX_NAMESPACE::OpSchema::all_numeric_types_with_bfloat(), + "Constrain output types to numeric tensors.") + .FunctionBody({WindowOpFunctionNode("Blackman")}, {onnx_op_set_17}); + + ONNX_NAMESPACE::NodeProto mel_function_node; + mel_function_node.set_op_type("MelWeightMatrix"); + + auto* mel_function_output_datatype = mel_function_node.add_attribute(); + 
mel_function_output_datatype->set_name("output_datatype"); + mel_function_output_datatype->set_ref_attr_name("output_datatype"); + + mel_function_node.add_input("num_mel_bins"); + mel_function_node.add_input("dft_length"); + mel_function_node.add_input("sample_rate"); + mel_function_node.add_input("lower_edge_hertz"); + mel_function_node.add_input("upper_edge_hertz"); + mel_function_node.add_output("output"); + + static const char* MelWeightMatrix_doc = R"DOC( Generate a MelWeightMatrix that can be used to re-weight a Tensor containing a linearly sampled frequency spectra -(from DFT or STFT) into num_mel_bins frequency information based on the [lower_edge_hertz, upper_edge_hertz] range -on the mel scale. -This function defines the mel scale in terms of a frequency in hertz according to the following formula: +(from DFT or STFT) into num_mel_bins frequency information based on the [lower_edge_hertz, upper_edge_hertz] range on +the mel scale. This function defines the mel scale in terms of a frequency in hertz according to the following formula: mel(f) = 2595 * log10(1 + f/700) @@ -651,51 +576,23 @@ linear scale spectrum values (e.g. STFT magnitudes) to generate a "mel spectrogr MS_SIGNAL_OPERATOR_SCHEMA(MelWeightMatrix) .SetDomain(kMSExperimentalDomain) .SinceVersion(1) - .SetDoc(R"DOC(MelWeightMatrix)DOC") + .SetDoc(MelWeightMatrix_doc) .Attr("output_datatype", "The data type of the output tensor. 
" "Strictly must be one of the types from DataType enum in TensorProto.", ONNX_NAMESPACE::AttributeProto::AttributeType::AttributeProto_AttributeType_INT, static_cast(onnx::TensorProto_DataType::TensorProto_DataType_FLOAT)) - .Input(0, - "num_mel_bins", - "The number of bands in the mel spectrum.", - "T1") - .Input(1, - "dft_length", - "The size of the FFT.", - "T1") - .Input(2, - "sample_rate", - "", - "T1") - .Input(3, - "lower_edge_hertz", - "", - "T2") - .Input(4, - "upper_edge_hertz", - "", - "T2") - .Output(0, - "output", - "The MEL Matrix", - "T3") - .TypeConstraint( - "T1", - {"tensor(int32)", "tensor(int64)"}, - "Constrain to integer tensors.") - .TypeConstraint( - "T2", - {"tensor(float)", - "tensor(float16)", - "tensor(double)", - "tensor(bfloat16)"}, - "Constrain to float tensors") - .TypeConstraint( - "T3", - ONNX_NAMESPACE::OpSchema::all_numeric_types_with_bfloat(), - "Constrain to any numerical types.") + .Input(0, "num_mel_bins", "The number of bands in the mel spectrum.", "T1") + .Input(1, "dft_length", "The size of the FFT.", "T1") + .Input(2, "sample_rate", "", "T1") + .Input(3, "lower_edge_hertz", "", "T2") + .Input(4, "upper_edge_hertz", "", "T2") + .Output(0, "output", "The MEL Matrix", "T3") + .TypeConstraint("T1", {"tensor(int32)", "tensor(int64)"}, "Constrain to integer tensors.") + .TypeConstraint("T2", {"tensor(float)", "tensor(float16)", "tensor(double)", "tensor(bfloat16)"}, + "Constrain to float tensors") + .TypeConstraint("T3", ONNX_NAMESPACE::OpSchema::all_numeric_types_with_bfloat(), + "Constrain to any numerical types.") .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { auto output_datatype = getAttribute( ctx, "output_datatype", static_cast(onnx::TensorProto::DataType::TensorProto_DataType_FLOAT)); @@ -729,7 +626,8 @@ linear scale spectrum values (e.g. 
STFT magnitudes) to generate a "mel spectrogr result_shape.add_dim()->set_dim_value(num_mel_bins_value); updateOutputShape(ctx, 0, result_shape); } - }); + }) + .FunctionBody({mel_function_node}, {onnx_op_set_17}); } } // namespace signal diff --git a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc index 62e1d1f73f353..0e76fabdfffc9 100644 --- a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc +++ b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc @@ -137,8 +137,10 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOn class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, double, Softmax); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 9, float, TopK); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 9, double, TopK); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 8, float, BatchNormalization); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 8, double, BatchNormalization); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 8, float, + BatchNormalization); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 8, double, + BatchNormalization); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, Conv); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, ConvTranspose); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 8, Flatten); @@ -159,8 +161,7 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOn class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, float, 
ReduceL2); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, int32_t, ReduceL2); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, float, ReduceLogSum); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, int32_t, - ReduceLogSum); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, int32_t, ReduceLogSum); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, float, ReduceLogSumExp); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, double, @@ -257,17 +258,28 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOn class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 12, Sign); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, Shrink); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 12, float, Erf); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10, int64_t_int64_t_int64_t, OneHot); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10, float_int64_t_int64_t, OneHot); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10, int64_t_string_int64_t, OneHot); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10, float_string_int64_t, OneHot); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10, float_float_float, OneHot); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10, int64_t_int32_t_float, OneHot); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10, 
int64_t_float_int64_t, OneHot); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10, int32_t_float_int32_t, OneHot); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10, int32_t_float_float, OneHot); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10, int64_t_float_float, OneHot); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10, int64_t_float_int32_t, OneHot); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10, + int64_t_int64_t_int64_t, OneHot); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10, float_int64_t_int64_t, + OneHot); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10, int64_t_string_int64_t, + OneHot); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10, float_string_int64_t, + OneHot); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10, float_float_float, + OneHot); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10, int64_t_int32_t_float, + OneHot); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10, int64_t_float_int64_t, + OneHot); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10, int32_t_float_int32_t, + OneHot); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10, int32_t_float_float, + OneHot); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10, int64_t_float_float, + OneHot); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10, int64_t_float_int32_t, + OneHot); class 
ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10, MaxUnpool); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, Sinh); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, Cosh); @@ -295,8 +307,10 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOn class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 12, double, MatMul); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 12, int32_t, MatMul); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 12, int64_t, MatMul); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 13, float, BatchNormalization); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 13, double, BatchNormalization); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 13, float, + BatchNormalization); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 13, double, + BatchNormalization); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 15, PRelu); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 9, float, Upsample); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 9, int32_t, Upsample); @@ -314,11 +328,16 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOn class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 10, int8_t, Resize); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 10, uint8_t, Resize); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, ThresholdedRelu); -class 
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 12, uint8_t, DequantizeLinear); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 12, int8_t, DequantizeLinear); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 12, int32_t, DequantizeLinear); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 12, uint8_t, QuantizeLinear); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 12, int8_t, QuantizeLinear); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 12, uint8_t, + DequantizeLinear); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 12, int8_t, + DequantizeLinear); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 12, int32_t, + DequantizeLinear); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 12, uint8_t, + QuantizeLinear); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 12, int8_t, + QuantizeLinear); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, uint8_t, QLinearMatMul); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, int8_t, QLinearMatMul); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, uint8_t, MatMulInteger); @@ -362,10 +381,14 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOn class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, float, ReduceL2); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, int32_t, ReduceL2); class 
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, float, ReduceLogSum); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, int32_t, ReduceLogSum); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, float, ReduceLogSumExp); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, double, ReduceLogSumExp); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, int32_t, ReduceLogSumExp); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, int32_t, + ReduceLogSum); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, float, + ReduceLogSumExp); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, double, + ReduceLogSumExp); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, int32_t, + ReduceLogSumExp); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 11, float, ReduceMax); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 11, double, ReduceMax); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 11, int32_t, ReduceMax); @@ -384,9 +407,12 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOn class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, double, ReduceSum); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, int32_t, ReduceSum); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, int64_t, ReduceSum); -class 
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, float, ReduceSumSquare); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, double, ReduceSumSquare); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, int32_t, ReduceSumSquare); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, float, + ReduceSumSquare); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, double, + ReduceSumSquare); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, int32_t, + ReduceSumSquare); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, Hardmax); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, float, LogSoftmax); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, double, LogSoftmax); @@ -480,15 +506,22 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 12, Ei // class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 12, MLFloat16_float, Dropout); // class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 12, MLFloat16_double, Dropout); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 12, 12, float_float, Dropout); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 12, 12, float_double, Dropout); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 12, 12, double_float, Dropout); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 12, 12, double_double, Dropout); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, 
kOnnxDomain, 12, 12, float_double, + Dropout); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 12, 12, double_float, + Dropout); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 12, 12, double_double, + Dropout); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 12, Celu); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 12, 15, float, GreaterOrEqual); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 12, 15, double, GreaterOrEqual); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 12, 15, int32_t, GreaterOrEqual); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 12, 15, int64_t, GreaterOrEqual); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 12, 15, float, + GreaterOrEqual); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 12, 15, double, + GreaterOrEqual); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 12, 15, int32_t, + GreaterOrEqual); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 12, 15, int64_t, + GreaterOrEqual); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 12, 15, float, LessOrEqual); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 12, 15, double, LessOrEqual); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 12, 15, int32_t, LessOrEqual); @@ -706,8 +739,10 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 14, int64_t, Div); class 
ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 14, Reshape); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 14, 15, Identity); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 14, 14, float, BatchNormalization); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 14, 14, double, BatchNormalization); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 14, 14, float, + BatchNormalization); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 14, 14, double, + BatchNormalization); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 14, GRU); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 14, LSTM); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 14, RNN); @@ -751,6 +786,14 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 16, int32_t, LessOrEqual); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 16, int64_t, LessOrEqual); +// Opset 17 +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 17, DFT); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 17, BlackmanWindow); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 17, HammingWindow); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 17, HannWindow); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 17, MelWeightMatrix); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 17, STFT); + // !!PLEASE READ BELOW!! Following that, add new entries above this comment /* *** IMPORTANT! 
*** @@ -795,95 +838,123 @@ KernelCreateInfo BuildKernelCreateInfo() { Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { static const BuildKernelCreateInfoFn function_table[] = { BuildKernelCreateInfo, // default entry to avoid the list become empty after ops-reducing - BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + 
BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -914,7 +985,8 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { double, Equal)>, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -922,8 +994,10 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - 
BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -1047,31 +1117,28 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, - BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, // Opset 9 BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - 
BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -1188,36 +1253,36 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { int64_t, NonZero)>, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -1246,27 +1311,30 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + 
BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, // opset 11 - BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -1299,10 +1366,8 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { float, Equal)>, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -1358,15 +1430,17 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + 
BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, // REVIEW(codemzs): ConstEigenVectorArrayMap.cast, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, // opset 13 BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - 
BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - 
BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -1851,53 +1912,41 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { Softmax)>, // OpSet 14 - BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -1931,9 +1980,12 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { 
BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -1953,6 +2005,13 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { LessOrEqual)>, BuildKernelCreateInfo, + // Opset 17 + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, }; for (auto& function_table_entry : function_table) { @@ -1999,12 +2058,18 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 1, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 1, int32_t, Scaler); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 1, SVMClassifier); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 1, SVMRegressor); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 1, 2, float, TreeEnsembleClassifier); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 1, 2, double, TreeEnsembleClassifier); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 1, 2, int64_t, TreeEnsembleClassifier); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 1, 2, int32_t, TreeEnsembleClassifier); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 1, 2, float, TreeEnsembleRegressor); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 1, 2, double, TreeEnsembleRegressor); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 1, 2, float, + TreeEnsembleClassifier); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, 
kMLDomain, 1, 2, double, + TreeEnsembleClassifier); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 1, 2, int64_t, + TreeEnsembleClassifier); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 1, 2, int32_t, + TreeEnsembleClassifier); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 1, 2, float, + TreeEnsembleRegressor); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 1, 2, double, + TreeEnsembleRegressor); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 1, ZipMap); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 2, float_string, LabelEncoder); @@ -2072,26 +2137,25 @@ Status RegisterOnnxMLOperatorKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo +#include +#include +#include +#include -#include "core/providers/common.h" #include "core/framework/op_kernel.h" +#include "core/platform/threadpool.h" +#include "core/providers/common.h" +#include "core/providers/cpu/signal/utils.h" #include "core/util/math_cpuonly.h" #include "Eigen/src/Core/Map.h" -#include "dft.h" -#include -#include "core/platform/threadpool.h" +namespace onnxruntime { -#include -#include +ONNX_CPU_OPERATOR_KERNEL(DFT, 17, + KernelDefBuilder() + .TypeConstraint("T1", BuildKernelDefConstraints()) + .TypeConstraint("T2", BuildKernelDefConstraints()), + DFT); -namespace 
onnxruntime { -namespace contrib { - -ONNX_OPERATOR_KERNEL_EX( - DFT, - kMSExperimentalDomain, - 1, - kCpuExecutionProvider, - KernelDefBuilder().TypeConstraint("T1", BuildKernelDefConstraints()) - .TypeConstraint("T2", BuildKernelDefConstraints()), - DFT); - -ONNX_OPERATOR_KERNEL_EX( - IDFT, - kMSExperimentalDomain, - 1, - kCpuExecutionProvider, - KernelDefBuilder().TypeConstraint("T1", BuildKernelDefConstraints()) - .TypeConstraint("T2", BuildKernelDefConstraints()), - IDFT); - -ONNX_OPERATOR_KERNEL_EX( - STFT, - kMSExperimentalDomain, - 1, - kCpuExecutionProvider, - KernelDefBuilder().MayInplace(0, 0).TypeConstraint("T1", BuildKernelDefConstraints()) - .TypeConstraint("T2", BuildKernelDefConstraints()), - STFT); - -// dedupe with the other one in window_functions.cc -template -static T get_scalar_value_from_tensor(const Tensor* tensor) { - ORT_ENFORCE(tensor->Shape().Size() == 1, "ratio input should have a single value."); - - auto data_type = tensor->DataType()->AsPrimitiveDataType()->GetDataType(); - switch (data_type) { - case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: - return static_cast(*reinterpret_cast(tensor->DataRaw())); - case ONNX_NAMESPACE::TensorProto_DataType_DOUBLE: - return static_cast(*reinterpret_cast(tensor->DataRaw())); - case ONNX_NAMESPACE::TensorProto_DataType_INT32: - return static_cast(*reinterpret_cast(tensor->DataRaw())); - case ONNX_NAMESPACE::TensorProto_DataType_INT64: - return static_cast(*reinterpret_cast(tensor->DataRaw())); - default: - ORT_THROW("Unsupported input data type of ", data_type); - } -} +ONNX_CPU_OPERATOR_KERNEL(STFT, 17, + KernelDefBuilder() + .MayInplace(0, 0) + .TypeConstraint("T1", BuildKernelDefConstraints()) + .TypeConstraint("T2", BuildKernelDefConstraints()), + STFT); -static bool is_real_valued_signal(const onnxruntime::TensorShape & shape) { +static bool is_real_valued_signal(const onnxruntime::TensorShape& shape) { return shape.NumDimensions() == 2 || shape[shape.NumDimensions() - 1] == 1; } @@ -82,71 
+48,100 @@ static bool is_power_of_2(size_t size) { return n_bits == 1; } -static const unsigned char BitReverseTable256[] = -{ - 0x00, 0x80, 0x40, 0xC0, 0x20, 0xA0, 0x60, 0xE0, 0x10, 0x90, 0x50, 0xD0, 0x30, 0xB0, 0x70, 0xF0, - 0x08, 0x88, 0x48, 0xC8, 0x28, 0xA8, 0x68, 0xE8, 0x18, 0x98, 0x58, 0xD8, 0x38, 0xB8, 0x78, 0xF8, - 0x04, 0x84, 0x44, 0xC4, 0x24, 0xA4, 0x64, 0xE4, 0x14, 0x94, 0x54, 0xD4, 0x34, 0xB4, 0x74, 0xF4, - 0x0C, 0x8C, 0x4C, 0xCC, 0x2C, 0xAC, 0x6C, 0xEC, 0x1C, 0x9C, 0x5C, 0xDC, 0x3C, 0xBC, 0x7C, 0xFC, - 0x02, 0x82, 0x42, 0xC2, 0x22, 0xA2, 0x62, 0xE2, 0x12, 0x92, 0x52, 0xD2, 0x32, 0xB2, 0x72, 0xF2, - 0x0A, 0x8A, 0x4A, 0xCA, 0x2A, 0xAA, 0x6A, 0xEA, 0x1A, 0x9A, 0x5A, 0xDA, 0x3A, 0xBA, 0x7A, 0xFA, - 0x06, 0x86, 0x46, 0xC6, 0x26, 0xA6, 0x66, 0xE6, 0x16, 0x96, 0x56, 0xD6, 0x36, 0xB6, 0x76, 0xF6, - 0x0E, 0x8E, 0x4E, 0xCE, 0x2E, 0xAE, 0x6E, 0xEE, 0x1E, 0x9E, 0x5E, 0xDE, 0x3E, 0xBE, 0x7E, 0xFE, - 0x01, 0x81, 0x41, 0xC1, 0x21, 0xA1, 0x61, 0xE1, 0x11, 0x91, 0x51, 0xD1, 0x31, 0xB1, 0x71, 0xF1, - 0x09, 0x89, 0x49, 0xC9, 0x29, 0xA9, 0x69, 0xE9, 0x19, 0x99, 0x59, 0xD9, 0x39, 0xB9, 0x79, 0xF9, - 0x05, 0x85, 0x45, 0xC5, 0x25, 0xA5, 0x65, 0xE5, 0x15, 0x95, 0x55, 0xD5, 0x35, 0xB5, 0x75, 0xF5, - 0x0D, 0x8D, 0x4D, 0xCD, 0x2D, 0xAD, 0x6D, 0xED, 0x1D, 0x9D, 0x5D, 0xDD, 0x3D, 0xBD, 0x7D, 0xFD, - 0x03, 0x83, 0x43, 0xC3, 0x23, 0xA3, 0x63, 0xE3, 0x13, 0x93, 0x53, 0xD3, 0x33, 0xB3, 0x73, 0xF3, - 0x0B, 0x8B, 0x4B, 0xCB, 0x2B, 0xAB, 0x6B, 0xEB, 0x1B, 0x9B, 0x5B, 0xDB, 0x3B, 0xBB, 0x7B, 0xFB, - 0x07, 0x87, 0x47, 0xC7, 0x27, 0xA7, 0x67, 0xE7, 0x17, 0x97, 0x57, 0xD7, 0x37, 0xB7, 0x77, 0xF7, - 0x0F, 0x8F, 0x4F, 0xCF, 0x2F, 0xAF, 0x6F, 0xEF, 0x1F, 0x9F, 0x5F, 0xDF, 0x3F, 0xBF, 0x7F, 0xFF}; +static const unsigned char BitReverseTable256[] = { + 0x00, 0x80, 0x40, 0xC0, 0x20, 0xA0, 0x60, 0xE0, 0x10, 0x90, 0x50, 0xD0, 0x30, 0xB0, 0x70, 0xF0, 0x08, 0x88, 0x48, + 0xC8, 0x28, 0xA8, 0x68, 0xE8, 0x18, 0x98, 0x58, 0xD8, 0x38, 0xB8, 0x78, 0xF8, 0x04, 0x84, 0x44, 0xC4, 0x24, 0xA4, + 0x64, 0xE4, 
0x14, 0x94, 0x54, 0xD4, 0x34, 0xB4, 0x74, 0xF4, 0x0C, 0x8C, 0x4C, 0xCC, 0x2C, 0xAC, 0x6C, 0xEC, 0x1C, + 0x9C, 0x5C, 0xDC, 0x3C, 0xBC, 0x7C, 0xFC, 0x02, 0x82, 0x42, 0xC2, 0x22, 0xA2, 0x62, 0xE2, 0x12, 0x92, 0x52, 0xD2, + 0x32, 0xB2, 0x72, 0xF2, 0x0A, 0x8A, 0x4A, 0xCA, 0x2A, 0xAA, 0x6A, 0xEA, 0x1A, 0x9A, 0x5A, 0xDA, 0x3A, 0xBA, 0x7A, + 0xFA, 0x06, 0x86, 0x46, 0xC6, 0x26, 0xA6, 0x66, 0xE6, 0x16, 0x96, 0x56, 0xD6, 0x36, 0xB6, 0x76, 0xF6, 0x0E, 0x8E, + 0x4E, 0xCE, 0x2E, 0xAE, 0x6E, 0xEE, 0x1E, 0x9E, 0x5E, 0xDE, 0x3E, 0xBE, 0x7E, 0xFE, 0x01, 0x81, 0x41, 0xC1, 0x21, + 0xA1, 0x61, 0xE1, 0x11, 0x91, 0x51, 0xD1, 0x31, 0xB1, 0x71, 0xF1, 0x09, 0x89, 0x49, 0xC9, 0x29, 0xA9, 0x69, 0xE9, + 0x19, 0x99, 0x59, 0xD9, 0x39, 0xB9, 0x79, 0xF9, 0x05, 0x85, 0x45, 0xC5, 0x25, 0xA5, 0x65, 0xE5, 0x15, 0x95, 0x55, + 0xD5, 0x35, 0xB5, 0x75, 0xF5, 0x0D, 0x8D, 0x4D, 0xCD, 0x2D, 0xAD, 0x6D, 0xED, 0x1D, 0x9D, 0x5D, 0xDD, 0x3D, 0xBD, + 0x7D, 0xFD, 0x03, 0x83, 0x43, 0xC3, 0x23, 0xA3, 0x63, 0xE3, 0x13, 0x93, 0x53, 0xD3, 0x33, 0xB3, 0x73, 0xF3, 0x0B, + 0x8B, 0x4B, 0xCB, 0x2B, 0xAB, 0x6B, 0xEB, 0x1B, 0x9B, 0x5B, 0xDB, 0x3B, 0xBB, 0x7B, 0xFB, 0x07, 0x87, 0x47, 0xC7, + 0x27, 0xA7, 0x67, 0xE7, 0x17, 0x97, 0x57, 0xD7, 0x37, 0xB7, 0x77, 0xF7, 0x0F, 0x8F, 0x4F, 0xCF, 0x2F, 0xAF, 0x6F, + 0xEF, 0x1F, 0x9F, 0x5F, 0xDF, 0x3F, 0xBF, 0x7F, 0xFF}; template uint32_t bit_reverse(uint32_t num) { - uint32_t rev = (BitReverseTable256[num & 0xff] << 24) | - (BitReverseTable256[(num >> 8) & 0xff] << 16) | - (BitReverseTable256[(num >> 16) & 0xff] << 8) | - (BitReverseTable256[(num >> 24) & 0xff]); + uint32_t rev = (BitReverseTable256[num & 0xff] << 24) | (BitReverseTable256[(num >> 8) & 0xff] << 16) | + (BitReverseTable256[(num >> 16) & 0xff] << 8) | (BitReverseTable256[(num >> 24) & 0xff]); return static_cast(((uint64_t)rev) >> (32 - TSignificantBits)); } template static inline T bit_reverse(T num, unsigned significant_bits) { switch (significant_bits) { - case 0: return static_cast(bit_reverse<0>(static_cast(num))); - 
case 1: return static_cast(bit_reverse<1>(static_cast(num))); - case 2: return static_cast(bit_reverse<2>(static_cast(num))); - case 3: return static_cast(bit_reverse<3>(static_cast(num))); - case 4: return static_cast(bit_reverse<4>(static_cast(num))); - case 5: return static_cast(bit_reverse<5>(static_cast(num))); - case 6: return static_cast(bit_reverse<6>(static_cast(num))); - case 7: return static_cast(bit_reverse<7>(static_cast(num))); - case 8: return static_cast(bit_reverse<8>(static_cast(num))); - case 9: return static_cast(bit_reverse<9>(static_cast(num))); - case 10: return static_cast(bit_reverse<10>(static_cast(num))); - case 11: return static_cast(bit_reverse<11>(static_cast(num))); - case 12: return static_cast(bit_reverse<12>(static_cast(num))); - case 13: return static_cast(bit_reverse<13>(static_cast(num))); - case 14: return static_cast(bit_reverse<14>(static_cast(num))); - case 15: return static_cast(bit_reverse<15>(static_cast(num))); - case 16: return static_cast(bit_reverse<16>(static_cast(num))); - case 17: return static_cast(bit_reverse<17>(static_cast(num))); - case 18: return static_cast(bit_reverse<18>(static_cast(num))); - case 19: return static_cast(bit_reverse<19>(static_cast(num))); - case 20: return static_cast(bit_reverse<20>(static_cast(num))); - case 21: return static_cast(bit_reverse<21>(static_cast(num))); - case 22: return static_cast(bit_reverse<22>(static_cast(num))); - case 23: return static_cast(bit_reverse<23>(static_cast(num))); - case 24: return static_cast(bit_reverse<24>(static_cast(num))); - case 25: return static_cast(bit_reverse<25>(static_cast(num))); - case 26: return static_cast(bit_reverse<26>(static_cast(num))); - case 27: return static_cast(bit_reverse<27>(static_cast(num))); - case 28: return static_cast(bit_reverse<28>(static_cast(num))); - case 29: return static_cast(bit_reverse<29>(static_cast(num))); - case 30: return static_cast(bit_reverse<30>(static_cast(num))); - case 31: return 
static_cast(bit_reverse<31>(static_cast(num))); - case 32: return static_cast(bit_reverse<32>(static_cast(num))); - default: ORT_THROW("Unsupported bit size."); + case 0: + return static_cast(bit_reverse<0>(static_cast(num))); + case 1: + return static_cast(bit_reverse<1>(static_cast(num))); + case 2: + return static_cast(bit_reverse<2>(static_cast(num))); + case 3: + return static_cast(bit_reverse<3>(static_cast(num))); + case 4: + return static_cast(bit_reverse<4>(static_cast(num))); + case 5: + return static_cast(bit_reverse<5>(static_cast(num))); + case 6: + return static_cast(bit_reverse<6>(static_cast(num))); + case 7: + return static_cast(bit_reverse<7>(static_cast(num))); + case 8: + return static_cast(bit_reverse<8>(static_cast(num))); + case 9: + return static_cast(bit_reverse<9>(static_cast(num))); + case 10: + return static_cast(bit_reverse<10>(static_cast(num))); + case 11: + return static_cast(bit_reverse<11>(static_cast(num))); + case 12: + return static_cast(bit_reverse<12>(static_cast(num))); + case 13: + return static_cast(bit_reverse<13>(static_cast(num))); + case 14: + return static_cast(bit_reverse<14>(static_cast(num))); + case 15: + return static_cast(bit_reverse<15>(static_cast(num))); + case 16: + return static_cast(bit_reverse<16>(static_cast(num))); + case 17: + return static_cast(bit_reverse<17>(static_cast(num))); + case 18: + return static_cast(bit_reverse<18>(static_cast(num))); + case 19: + return static_cast(bit_reverse<19>(static_cast(num))); + case 20: + return static_cast(bit_reverse<20>(static_cast(num))); + case 21: + return static_cast(bit_reverse<21>(static_cast(num))); + case 22: + return static_cast(bit_reverse<22>(static_cast(num))); + case 23: + return static_cast(bit_reverse<23>(static_cast(num))); + case 24: + return static_cast(bit_reverse<24>(static_cast(num))); + case 25: + return static_cast(bit_reverse<25>(static_cast(num))); + case 26: + return static_cast(bit_reverse<26>(static_cast(num))); + case 27: + return 
static_cast(bit_reverse<27>(static_cast(num))); + case 28: + return static_cast(bit_reverse<28>(static_cast(num))); + case 29: + return static_cast(bit_reverse<29>(static_cast(num))); + case 30: + return static_cast(bit_reverse<30>(static_cast(num))); + case 31: + return static_cast(bit_reverse<31>(static_cast(num))); + case 32: + return static_cast(bit_reverse<32>(static_cast(num))); + default: + ORT_THROW("Unsupported bit size."); } } @@ -161,13 +156,10 @@ static T compute_angular_velocity(size_t number_of_samples, bool inverse) { } template -static Status fft_radix2(OpKernelContext* /*ctx*/, - const Tensor* X, Tensor* Y, - size_t X_offset, size_t X_stride, size_t Y_offset, size_t Y_stride, int64_t axis, size_t dft_length, - const Tensor* window, bool is_onesided, bool inverse, - std::vector>& V, - std::vector>& temp_output) { - +static Status fft_radix2(OpKernelContext* /*ctx*/, const Tensor* X, Tensor* Y, size_t X_offset, size_t X_stride, + size_t Y_offset, size_t Y_stride, int64_t axis, size_t dft_length, const Tensor* window, + bool is_onesided, bool inverse, InlinedVector>& V, + InlinedVector>& temp_output) { // Get shape and significant bits const auto& X_shape = X->Shape(); size_t number_of_samples = static_cast(X_shape[axis]); @@ -185,7 +177,7 @@ static Status fft_radix2(OpKernelContext* /*ctx*/, std::complex* Y_data; if (is_onesided) { if (temp_output.size() != dft_length) { - temp_output = std::vector>(dft_length); + temp_output = InlinedVector>(dft_length); } Y_data = temp_output.data(); } else { @@ -197,18 +189,19 @@ static Status fft_radix2(OpKernelContext* /*ctx*/, // Create vandermonde matrix V ordered with the bit-reversed permutation if (V.size() != dft_length) { - V = std::vector>(dft_length); // e^(i *2*pi / N * k) + V = InlinedVector>(dft_length); // e^(i *2*pi / N * k) for (size_t i = 0; i < dft_length; i++) { size_t bit_reversed_index = bit_reverse(i, significant_bits); - V[bit_reversed_index] = std::complex(cos(i * angular_velocity), sin(i 
* angular_velocity)); + V[bit_reversed_index] = std::complex(cos(static_cast(i) * angular_velocity), + sin(static_cast(i) * angular_velocity)); } } for (size_t i = 0; i < dft_length; i++) { size_t bit_reversed_index = bit_reverse(i, significant_bits); - auto x = (bit_reversed_index < number_of_samples) ? * (X_data + bit_reversed_index * X_stride) : 0; + auto x = (bit_reversed_index < number_of_samples) ? *(X_data + bit_reversed_index * X_stride) : 0; auto window_element = window_data ? *(window_data + bit_reversed_index) : 1; - *(Y_data + i*Y_data_stride) = std::complex(1, 0) * x * window_element; + *(Y_data + i * Y_data_stride) = std::complex(1, 0) * x * window_element; } // Run fft_radix2 @@ -222,7 +215,7 @@ static Status fft_radix2(OpKernelContext* /*ctx*/, auto second_idx = bit_reverse(midpoint + k, current_significant_bits); for (size_t j = 0; j < dft_length; j += i) { auto even_index = k + j; - auto odd_index = k + j + midpoint; + auto odd_index = k + j + midpoint; std::complex* even = (Y_data + even_index * Y_data_stride); std::complex* odd = (Y_data + odd_index * Y_data_stride); std::complex first = *even + (V[first_idx] * *odd); @@ -252,9 +245,8 @@ static Status fft_radix2(OpKernelContext* /*ctx*/, } template -static Status dft_naive(const Tensor* X, Tensor* Y, - size_t X_offset, size_t X_stride, size_t Y_offset, size_t Y_stride, int64_t axis, - size_t dft_length, const Tensor* window, bool inverse) { +static Status dft_naive(const Tensor* X, Tensor* Y, size_t X_offset, size_t X_stride, size_t Y_offset, size_t Y_stride, + int64_t axis, size_t dft_length, const Tensor* window, bool inverse) { // Get shape and significant bits const auto& X_shape = X->Shape(); size_t number_of_samples = static_cast(X_shape[axis]); @@ -273,13 +265,15 @@ static Status dft_naive(const Tensor* X, Tensor* Y, auto angular_velocity = compute_angular_velocity(dft_length, inverse); for (size_t i = 0; i < dft_output_size; i++) { - std::complex& out = *(Y_data + i*Y_stride); + 
std::complex& out = *(Y_data + i * Y_stride); out.real(0); out.imag(0); for (size_t j = 0; j < dft_length; j++) { // vectorize over this loop - auto exponential = std::complex(cos(i * j * angular_velocity), sin(i * j * angular_velocity)); - auto window_element = window_data ? * (window_data + j) : 1; + auto exponential = std::complex( + cos(static_cast(i) * static_cast(j) * angular_velocity), + sin(static_cast(i) * static_cast(j) * angular_velocity)); + auto window_element = window_data ? *(window_data + j) : 1; auto x = (j < number_of_samples) ? *(X_data + j * X_stride) : 0; auto element = x * window_element; out += exponential * element; @@ -294,8 +288,10 @@ static Status dft_naive(const Tensor* X, Tensor* Y, } template -static Status discrete_fourier_transform(OpKernelContext* ctx, const Tensor* X, Tensor* Y, int64_t axis, int64_t dft_length, const Tensor* window, bool is_onesided, bool inverse, - std::vector>& V, std::vector>& temp_output) { +static Status discrete_fourier_transform(OpKernelContext* ctx, const Tensor* X, Tensor* Y, int64_t axis, + int64_t dft_length, const Tensor* window, bool is_onesided, bool inverse, + InlinedVector>& V, + InlinedVector>& temp_output) { // Get shape const auto& X_shape = X->Shape(); const auto& Y_shape = Y->Shape(); @@ -305,22 +301,19 @@ static Status discrete_fourier_transform(OpKernelContext* ctx, const Tensor* X, auto is_input_real = X->Shape().NumDimensions() == 2 || X->Shape()[X->Shape().NumDimensions() - 1] == 1; auto complex_input_factor = is_input_real ? 
1 : 2; - if (X->Shape().NumDimensions() > 2) - { + if (X->Shape().NumDimensions() > 2) { total_dfts /= X->Shape()[X->Shape().NumDimensions() - 1]; batch_and_signal_rank -= 1; } // Calculate x/y offsets/strides - for (size_t i = 0; i < total_dfts; i++) - { + for (size_t i = 0; i < total_dfts; i++) { size_t X_offset = 0; - size_t X_stride = X_shape.SizeFromDimension(axis+1) / complex_input_factor; + size_t X_stride = X_shape.SizeFromDimension(axis + 1) / complex_input_factor; size_t cumulative_packed_stride = total_dfts; size_t temp = i; for (size_t r = 0; r < batch_and_signal_rank; r++) { - if (r == static_cast(axis)) - { + if (r == static_cast(axis)) { continue; } cumulative_packed_stride /= X_shape[r]; @@ -334,8 +327,7 @@ static Status discrete_fourier_transform(OpKernelContext* ctx, const Tensor* X, cumulative_packed_stride = total_dfts; temp = i; for (size_t r = 0; r < batch_and_signal_rank; r++) { - if (r == static_cast(axis)) - { + if (r == static_cast(axis)) { continue; } cumulative_packed_stride /= X_shape[r]; @@ -345,9 +337,11 @@ static Status discrete_fourier_transform(OpKernelContext* ctx, const Tensor* X, } if (is_power_of_2(dft_length)) { - ORT_RETURN_IF_ERROR((fft_radix2(ctx, X, Y, X_offset, X_stride, Y_offset, Y_stride, axis, dft_length, window, is_onesided, inverse, V, temp_output))); + ORT_RETURN_IF_ERROR((fft_radix2(ctx, X, Y, X_offset, X_stride, Y_offset, Y_stride, axis, dft_length, window, + is_onesided, inverse, V, temp_output))); } else { - ORT_RETURN_IF_ERROR((dft_naive(X, Y, X_offset, X_stride, Y_offset, Y_stride, axis, dft_length, window, inverse))); + ORT_RETURN_IF_ERROR( + (dft_naive(X, Y, X_offset, X_stride, Y_offset, Y_stride, axis, dft_length, window, inverse))); } } @@ -366,10 +360,7 @@ static Status discrete_fourier_transform(OpKernelContext* ctx, int64_t axis, boo // Ensure that the axis is in the valid range of [-rank, rank) auto rank = static_cast(X_shape.GetDims().size()); if (!(-rank <= axis && axis < rank)) { - 
ORT_RETURN_IF(!(-rank <= axis && axis < rank), - "axis attribute value ", - axis, - " is invalid for a tensor of rank ", + ORT_RETURN_IF(!(-rank <= axis && axis < rank), "axis attribute value ", axis, " is invalid for a tensor of rank ", rank); } axis = (axis >= 0 ? axis : axis + rank); @@ -378,23 +369,19 @@ static Status discrete_fourier_transform(OpKernelContext* ctx, int64_t axis, boo if (dft_length) { const auto& dft_length_shape = dft_length->Shape(); ORT_RETURN_IF(!dft_length_shape.IsScalar(), "dft_length must be a scalar value."); - number_of_samples = static_cast(get_scalar_value_from_tensor(dft_length)); + number_of_samples = static_cast(signal::get_scalar_value_from_tensor(dft_length)); ORT_RETURN_IF(number_of_samples <= 0, "dft_length must be greater than zero."); } // Get the DFT output size. Onesided will return only the unique values! // note: x >> 1 === std::floor(x / 2.f) - auto dft_output_size = is_onesided ? - ((number_of_samples >> 1) + 1) : - number_of_samples; + auto dft_output_size = is_onesided ? 
((number_of_samples >> 1) + 1) : number_of_samples; // Get output shape auto Y_shape = onnxruntime::TensorShape(X_shape); - if (X_shape.NumDimensions() == 2) - { + if (X_shape.NumDimensions() == 2) { Y_shape = onnxruntime::TensorShape({X_shape[0], dft_output_size, 2}); - } else - { + } else { Y_shape[Y_shape.NumDimensions() - 1] = 2; } Y_shape[axis] = dft_output_size; @@ -405,24 +392,36 @@ static Status discrete_fourier_transform(OpKernelContext* ctx, int64_t axis, boo auto element_size = data_type->Size(); if (element_size == sizeof(float)) { - std::vector> V; - std::vector> temp_output; + InlinedVector> V; + InlinedVector> temp_output; if (is_real_valued) { - ORT_RETURN_IF_ERROR((discrete_fourier_transform(ctx, X, Y, axis, number_of_samples, nullptr, is_onesided, inverse, V, temp_output))); + ORT_RETURN_IF_ERROR((discrete_fourier_transform(ctx, X, Y, axis, number_of_samples, nullptr, + is_onesided, inverse, V, temp_output))); } else if (is_complex_valued) { - ORT_RETURN_IF_ERROR((discrete_fourier_transform>(ctx, X, Y, axis, number_of_samples, nullptr, is_onesided, inverse, V, temp_output))); + ORT_RETURN_IF_ERROR((discrete_fourier_transform>( + ctx, X, Y, axis, number_of_samples, nullptr, is_onesided, inverse, V, temp_output))); } else { - ORT_THROW("Unsupported input signal shape. The signal's first dimenstion must be the batch dimension and its second dimension must be the signal length dimension. It may optionally include a 3rd dimension of size 2 for complex inputs.", data_type); + ORT_THROW( + "Unsupported input signal shape. The signal's first dimenstion must be the batch dimension and its second " + "dimension must be the signal length dimension. 
It may optionally include a 3rd dimension of size 2 for " + "complex inputs.", + data_type); } } else if (element_size == sizeof(double)) { - std::vector> V; - std::vector> temp_output; + InlinedVector> V; + InlinedVector> temp_output; if (is_real_valued) { - ORT_RETURN_IF_ERROR((discrete_fourier_transform(ctx, X, Y, axis, number_of_samples, nullptr, is_onesided, inverse, V, temp_output))); + ORT_RETURN_IF_ERROR((discrete_fourier_transform(ctx, X, Y, axis, number_of_samples, nullptr, + is_onesided, inverse, V, temp_output))); } else if (is_complex_valued) { - ORT_RETURN_IF_ERROR((discrete_fourier_transform>(ctx, X, Y, axis, number_of_samples, nullptr, is_onesided, inverse, V, temp_output))); + ORT_RETURN_IF_ERROR((discrete_fourier_transform>( + ctx, X, Y, axis, number_of_samples, nullptr, is_onesided, inverse, V, temp_output))); } else { - ORT_THROW("Unsupported input signal shape. The signal's first dimenstion must be the batch dimension and its second dimension must be the signal length dimension. It may optionally include a 3rd dimension of size 2 for complex inputs.", data_type); + ORT_THROW( + "Unsupported input signal shape. The signal's first dimenstion must be the batch dimension and its second " + "dimension must be the signal length dimension. 
It may optionally include a 3rd dimension of size 2 for " + "complex inputs.", + data_type); } } else { ORT_THROW("Unsupported input data type of ", data_type); @@ -432,20 +431,7 @@ static Status discrete_fourier_transform(OpKernelContext* ctx, int64_t axis, boo } Status DFT::Compute(OpKernelContext* ctx) const { - ORT_RETURN_IF_ERROR( - discrete_fourier_transform(ctx, - axis_, - is_onesided_, - is_inverse_)); - return Status::OK(); -} - -Status IDFT::Compute(OpKernelContext* ctx) const { - ORT_RETURN_IF_ERROR( - discrete_fourier_transform(ctx, - axis_, - false /*is_onesided_*/, - true /*is_inverse_*/)); + ORT_RETURN_IF_ERROR(discrete_fourier_transform(ctx, axis_, is_onesided_, is_inverse_)); return Status::OK(); } @@ -460,7 +446,7 @@ static Status short_time_fourier_transform(OpKernelContext* ctx, bool is_oneside // Get signal const auto* signal = ctx->Input(0); - const auto frame_step = get_scalar_value_from_tensor(ctx->Input(1)); + const auto frame_step = signal::get_scalar_value_from_tensor(ctx->Input(1)); const auto* window = ctx->Input(2); const auto* frame_length_tensor = ctx->Input(3); @@ -468,27 +454,28 @@ static Status short_time_fourier_transform(OpKernelContext* ctx, bool is_oneside const auto& signal_shape = signal->Shape(); const auto batch_size = signal_shape[0]; const auto signal_size = signal_shape[1]; - const auto signal_components = - signal_shape.NumDimensions() == 2 ? 1 : signal_shape.NumDimensions() == 3 ? signal_shape[2] : 0; // error + const auto signal_components = signal_shape.NumDimensions() == 2 ? 1 + : signal_shape.NumDimensions() == 3 ? 
signal_shape[2] + : 0; // error ORT_ENFORCE(signal_components == 1 || signal_components == 2, "Ensure that the signal has either 1 or 2 components."); // Get the frame length int64_t frame_length = std::numeric_limits::min(); - if (frame_length_tensor) - { - frame_length = get_scalar_value_from_tensor(frame_length_tensor); + if (frame_length_tensor) { + frame_length = signal::get_scalar_value_from_tensor(frame_length_tensor); } // Get window length int64_t window_length = std::numeric_limits::min(); - if (window) { + if (window) { window_length = window->Shape()[0]; } - // The frame_length and window inputs are generally used interchangably, and should match! - if (frame_length != std::numeric_limits::min() && - window_length != std::numeric_limits::min()) { - ORT_ENFORCE(frame_length == window_length, "If both frame_length and window are set, then the size of the window must be equal to the frame_length."); + // The frame_length and window inputs are generally used interchangeably, and should match! + if (frame_length != std::numeric_limits::min() && window_length != std::numeric_limits::min()) { + ORT_ENFORCE( + frame_length == window_length, + "If both frame_length and window are set, then the size of the window must be equal to the frame_length."); } // Calculate the window size with preference to the window input. @@ -496,14 +483,12 @@ static Status short_time_fourier_transform(OpKernelContext* ctx, bool is_oneside ORT_ENFORCE(window_size < signal_size, "Ensure that the dft size is smaller than the signal."); // Calculate the number of dfts to run - const auto n_dfts = static_cast(std::floor((signal_size - window_size) / static_cast(frame_step)) + 1); + const auto n_dfts = + static_cast(std::floor((signal_size - window_size) / static_cast(frame_step)) + 1); // Calculate the output spectra length (onesided will return only the unique values) // note: x >> 1 === std::floor(x / 2.f) - const auto dft_output_size = - is_onesided ? 
- (window_size >> 1) + 1 : - window_size; + const auto dft_output_size = is_onesided ? (window_size >> 1) + 1 : window_size; // Get/create the output mutable data auto output_spectra_shape = onnxruntime::TensorShape({batch_size, n_dfts, dft_output_size, 2}); @@ -518,41 +503,26 @@ static Status short_time_fourier_transform(OpKernelContext* ctx, bool is_oneside auto dft_input_shape = onnxruntime::TensorShape({1, window_size, signal_components}); auto dft_output_shape = onnxruntime::TensorShape({1, dft_output_size, output_components}); - std::vector> V; - std::vector> temp_output; + InlinedVector> V; + InlinedVector> temp_output; // Run each dft of each batch as if it was a real-valued batch size 1 dft operation for (int64_t batch_idx = 0; batch_idx < batch_size; batch_idx++) { for (int64_t i = 0; i < n_dfts; i++) { auto input_frame_begin = - signal_data + - (batch_idx * signal_size * signal_components) + - (i * frame_step * signal_components); + signal_data + (batch_idx * signal_size * signal_components) + (i * frame_step * signal_components); - auto output_frame_begin = - Y_data + - (batch_idx * n_dfts * dft_output_size * output_components) + - (i * dft_output_size * output_components); + auto output_frame_begin = Y_data + (batch_idx * n_dfts * dft_output_size * output_components) + + (i * dft_output_size * output_components); // Tensors do not own the backing memory, so no worries on destruction - auto input = - onnxruntime::Tensor( - signal->DataType(), - dft_input_shape, - input_frame_begin, - signal->Location(), - 0); - - auto output = - onnxruntime::Tensor( - Y->DataType(), - dft_output_shape, - output_frame_begin, - Y->Location(), - 0); + auto input = onnxruntime::Tensor(signal->DataType(), dft_input_shape, input_frame_begin, signal->Location(), 0); + + auto output = onnxruntime::Tensor(Y->DataType(), dft_output_shape, output_frame_begin, Y->Location(), 0); // Run individual dft - ORT_RETURN_IF_ERROR((discrete_fourier_transform(ctx, &input, &output, 1, 
window_size, window, is_onesided, false, V, temp_output))); + ORT_RETURN_IF_ERROR((discrete_fourier_transform(ctx, &input, &output, 1, window_size, window, is_onesided, + false, V, temp_output))); } } @@ -583,7 +553,11 @@ Status STFT::Compute(OpKernelContext* ctx) const { } else if (is_complex_valued) { ORT_RETURN_IF_ERROR((short_time_fourier_transform>(ctx, is_onesided_, false))); } else { - ORT_THROW("Unsupported input signal shape. The signal's first dimenstion must be the batch dimension and its second dimension must be the signal length dimension. It may optionally include a 3rd dimension of size 2 for complex inputs.", data_type); + ORT_THROW( + "Unsupported input signal shape. The signal's first dimenstion must be the batch dimension and its second " + "dimension must be the signal length dimension. It may optionally include a 3rd dimension of size 2 for " + "complex inputs.", + data_type); } } else if (element_size == sizeof(double)) { if (is_real_valued) { @@ -591,7 +565,11 @@ Status STFT::Compute(OpKernelContext* ctx) const { } else if (is_complex_valued) { ORT_RETURN_IF_ERROR((short_time_fourier_transform>(ctx, is_onesided_, false))); } else { - ORT_THROW("Unsupported input signal shape. The signal's first dimenstion must be the batch dimension and its second dimension must be the signal length dimension. It may optionally include a 3rd dimension of size 2 for complex inputs.", data_type); + ORT_THROW( + "Unsupported input signal shape. The signal's first dimenstion must be the batch dimension and its second " + "dimension must be the signal length dimension. 
It may optionally include a 3rd dimension of size 2 for " + "complex inputs.", + data_type); } } else { ORT_THROW("Unsupported input data type of ", data_type); @@ -600,7 +578,4 @@ Status STFT::Compute(OpKernelContext* ctx) const { return Status::OK(); } -} // namespace contrib } // namespace onnxruntime - -#endif diff --git a/onnxruntime/contrib_ops/cpu/signal/dft.h b/onnxruntime/core/providers/cpu/signal/dft.h similarity index 72% rename from onnxruntime/contrib_ops/cpu/signal/dft.h rename to onnxruntime/core/providers/cpu/signal/dft.h index e177eb877ea7b..17f25a8f975a3 100644 --- a/onnxruntime/contrib_ops/cpu/signal/dft.h +++ b/onnxruntime/core/providers/cpu/signal/dft.h @@ -1,15 +1,16 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#ifdef BUILD_MS_EXPERIMENTAL_OPS +#include "core/common/common.h" +#include "core/framework/op_kernel.h" namespace onnxruntime { -namespace contrib { class DFT final : public OpKernel { bool is_onesided_ = true; int64_t axis_ = 0; bool is_inverse_ = false; + public: explicit DFT(const OpKernelInfo& info) : OpKernel(info) { is_onesided_ = static_cast(info.GetAttrOrDefault("onesided", 0)); @@ -19,17 +20,9 @@ class DFT final : public OpKernel { Status Compute(OpKernelContext* ctx) const override; }; -class IDFT final : public OpKernel { - int64_t axis_ = 0; - public: - explicit IDFT(const OpKernelInfo& info) : OpKernel(info) { - axis_ = info.GetAttrOrDefault("axis", 0); - } - Status Compute(OpKernelContext* ctx) const override; -}; - class STFT final : public OpKernel { bool is_onesided_ = true; + public: explicit STFT(const OpKernelInfo& info) : OpKernel(info) { is_onesided_ = static_cast(info.GetAttrOrDefault("onesided", 1)); @@ -37,7 +30,4 @@ class STFT final : public OpKernel { Status Compute(OpKernelContext* ctx) const override; }; -} // namespace contrib } // namespace onnxruntime - -#endif diff --git a/onnxruntime/core/providers/cpu/signal/utils.h 
b/onnxruntime/core/providers/cpu/signal/utils.h new file mode 100644 index 0000000000000..c3bc949844d03 --- /dev/null +++ b/onnxruntime/core/providers/cpu/signal/utils.h @@ -0,0 +1,31 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/framework/tensor.h" + +namespace onnxruntime { +namespace signal { + +template +static T get_scalar_value_from_tensor(const Tensor* tensor) { + ORT_ENFORCE(tensor->Shape().Size() == 1, "ratio input should have a single value."); + + auto data_type = tensor->GetElementType(); + switch (data_type) { + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: + return static_cast(*reinterpret_cast(tensor->DataRaw())); + case ONNX_NAMESPACE::TensorProto_DataType_DOUBLE: + return static_cast(*reinterpret_cast(tensor->DataRaw())); + case ONNX_NAMESPACE::TensorProto_DataType_INT32: + return static_cast(*reinterpret_cast(tensor->DataRaw())); + case ONNX_NAMESPACE::TensorProto_DataType_INT64: + return static_cast(*reinterpret_cast(tensor->DataRaw())); + default: + ORT_THROW("Unsupported input data type of ", data_type); + } +} + +} // namespace signal +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cpu/signal/window_functions.cc b/onnxruntime/core/providers/cpu/signal/window_functions.cc new file mode 100644 index 0000000000000..552f930a46e8f --- /dev/null +++ b/onnxruntime/core/providers/cpu/signal/window_functions.cc @@ -0,0 +1,215 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include "core/providers/cpu/signal/window_functions.h" + +#include + +#include "core/providers/common.h" +#include "core/providers/cpu/signal/utils.h" + +namespace onnxruntime { +ONNX_CPU_OPERATOR_KERNEL(HannWindow, 17, + KernelDefBuilder() + .MayInplace(0, 0) // + .TypeConstraint("T1", BuildKernelDefConstraints()) // + .TypeConstraint("T2", + BuildKernelDefConstraints()), + HannWindow); + +ONNX_CPU_OPERATOR_KERNEL(HammingWindow, 17, + KernelDefBuilder() + .MayInplace(0, 0) // + .TypeConstraint("T1", BuildKernelDefConstraints()) // + .TypeConstraint("T2", + BuildKernelDefConstraints()), + HammingWindow); + +ONNX_CPU_OPERATOR_KERNEL(BlackmanWindow, 17, + KernelDefBuilder() + .MayInplace(0, 0) // + .TypeConstraint("T1", BuildKernelDefConstraints()) // + .TypeConstraint("T2", + BuildKernelDefConstraints()), + BlackmanWindow); + +ONNX_CPU_OPERATOR_KERNEL(MelWeightMatrix, 17, + KernelDefBuilder() + .MayInplace(0, 0) // + .TypeConstraint("T1", BuildKernelDefConstraints()) // + .TypeConstraint("T2", BuildKernelDefConstraints()) + .TypeConstraint("T3", + BuildKernelDefConstraints()), + MelWeightMatrix); + +template +struct CosineSumWindow { + Status operator()(Tensor* Y, size_t size, float a0, float a1, float a2) { + auto* Y_data = reinterpret_cast(Y->MutableDataRaw()); + + // Calculate the radians to increment per sample + constexpr double pi = 3.14159265; + constexpr double tau = 2 * pi; + const double angular_increment = tau / size; + + for (size_t i = 0; i < size; i++) { + auto a2_component = a2 == 0 ? 
0 : (a2 * cos(2 * angular_increment * i)); + + T& value = *(Y_data + i); + value = static_cast(a0 - (a1 * cos(angular_increment * i)) + a2_component); + } + + return Status::OK(); + } +}; + +static Status create_cosine_sum_window(OpKernelContext* ctx, onnx::TensorProto_DataType output_datatype, float a0, + float a1, float a2) { + // Get the size of the window + auto size = signal::get_scalar_value_from_tensor(ctx->Input(0)); + + // Get the output tensor + auto Y_shape = TensorShape({size}); + auto Y = ctx->Output(0, Y_shape); + + utils::MLTypeCallDispatcher + dispatcher(output_datatype); + return dispatcher.InvokeRet(Y, size, a0, a1, a2); +} + +Status HannWindow::Compute(OpKernelContext* ctx) const { + // HannWindows are a special case of Cosine-Sum Windows which take the following form: + // w[n] = SUM_k=0_K( (-1)^k * a_k * cos(2*pi*k*n/N) ) with values the following values for a_k: + float a0 = .5f; + float a1 = a0; + float a2 = 0; + return create_cosine_sum_window(ctx, data_type_, a0, a1, a2); +} + +Status HammingWindow::Compute(OpKernelContext* ctx) const { + // HammingWindows are a special case of Cosine-Sum Windows which take the following form: + // w[n] = SUM_k=0_K( (-1)^k * a_k * cos(2*pi*k*n/N) ) with values the following values for a_k: + float a0 = 25.f / 46.f; + float a1 = 1 - a0; + float a2 = 0; + return create_cosine_sum_window(ctx, data_type_, a0, a1, a2); +} + +Status BlackmanWindow::Compute(OpKernelContext* ctx) const { + // BlackmanWindows are a special case of Cosine-Sum Windows which take the following form: + // w[n] = SUM_k=0_K( (-1)^k * a_k * cos(2*pi*k*n/N) ) with values the following values for a_k: + float alpha = .16f; + float a2 = alpha / 2.f; + float a0 = .5f - a2; + float a1 = .5f; + return create_cosine_sum_window(ctx, data_type_, a0, a1, a2); +} + +static inline double hz_to_mel_scale(double hz) { return 2595 * std::log10(1 + hz / 700); } + +static inline double mel_scale_to_hz(double mels) { return 700 * (pow(10, (mels / 2595)) - 
1); } + +template +struct CreateMelWeightMatrix { + Status operator()(OpKernelContext* ctx, int64_t num_mel_bins, int64_t dft_length, int64_t sample_rate, + float lower_edge_hertz, float upper_edge_hertz) { + // Determine the width of the spectrogram. + // This is determined as half the size of the fft size. The first element of the spectrum is always retained, + // and the remaining are halved. The second half can be discarded due to the conjugate symmetry of the output with + // real valued ffts. Taken together the formula for the size of the output will be std::floor(dft_length / 2) + 1. + int64_t num_spectrogram_bins = static_cast(std::floor(dft_length / 2 + 1)); + + // Checks + auto lowest_index = std::floor(((dft_length + 1) * lower_edge_hertz) / sample_rate); + auto highest_index = std::floor(((dft_length + 1) * upper_edge_hertz) / sample_rate); + ORT_ENFORCE( + lowest_index >= 0 && lowest_index < num_spectrogram_bins, + "lower_edge_hertz produces a mel triangle filter bank that is out of range given the dft_length and the " + "sample_rate."); + ORT_ENFORCE( + highest_index >= 0 && highest_index < num_spectrogram_bins, + "upper_edge_hertz produces a mel triangle filter bank that is out of range given the dft_length and the " + "sample_rate."); + + // Create the output shape + TensorShape output_shape({static_cast(num_spectrogram_bins), num_mel_bins}); + auto* Y = ctx->Output(0, output_shape); + + // Get the raw output data + auto* Y_data = reinterpret_cast(Y->MutableDataRaw()); + + // Set the weight matrix to 0 + memset(Y_data, 0, num_spectrogram_bins * num_mel_bins * sizeof(T)); + + // The mel filterbank is a triangular shaped peak with a height of 1 and a base equal to the size of the MEL range + // divided by the number of bins needed times 2. This triangle is then slid across the mel domain linearly, with a + // constant step size that is equal to half of the base of the triangle. 
To accommodate N bins, N+2 data points will + // be needed to determine the start, center and end points of each mel triangle filter. + // + // low_frequency where the mel triangle filter banks begin, and they end on the high_frequency_mel + // The range is divided evenly to create the needed points corresponding to the begin, center, end points of each + // triangle filterbank + InlinedVector frequency_bins(num_mel_bins + 2); + auto low_frequency_mel = hz_to_mel_scale(lower_edge_hertz); + auto high_frequency_mel = hz_to_mel_scale(upper_edge_hertz); + auto mel_step = (high_frequency_mel - low_frequency_mel) / static_cast(frequency_bins.size()); + + // Convert each point from mel scale back to hertz, and then compute the corresponding index in the fft + for (size_t i = 0; i < frequency_bins.size(); i++) { + auto hz = mel_scale_to_hz(low_frequency_mel + mel_step * i); + frequency_bins[i] = static_cast(std::floor(((dft_length + 1) * hz) / sample_rate)); + } + + for (size_t i = 0; i < static_cast(num_mel_bins); i++) { + auto lower_frequency_value = frequency_bins[i]; // left + auto center_frequency_point = frequency_bins[i + 1]; // center + auto higher_frequency_point = frequency_bins[i + 2]; // right + + auto low_to_center = center_frequency_point - lower_frequency_value; + if (low_to_center == 0) { + auto& current_element = *(Y_data + (center_frequency_point * num_mel_bins) + i); + current_element = static_cast(1); + } else { + for (size_t j = lower_frequency_value; j <= center_frequency_point; j++) { + auto& current_element = *(Y_data + (j * num_mel_bins) + i); + current_element = static_cast((j - lower_frequency_value) / static_cast(low_to_center)); + } + } + + auto center_to_high = higher_frequency_point - center_frequency_point; + if (center_to_high > 0) { + for (size_t j = center_frequency_point; j < higher_frequency_point; j++) { + auto& current_element = *(Y_data + (j * num_mel_bins) + i); + current_element = static_cast((higher_frequency_point - j) / 
static_cast(center_to_high)); + } + } + } + + return Status::OK(); + } +}; + +static Status create_mel_weight_matrix(OpKernelContext* ctx, onnx::TensorProto_DataType output_datatype, + int64_t num_mel_bins, int64_t dft_length, int64_t sample_rate, + float lower_edge_hertz, float upper_edge_hertz) { + utils::MLTypeCallDispatcher + dispatcher(output_datatype); + return dispatcher.InvokeRet(ctx, num_mel_bins, dft_length, sample_rate, + lower_edge_hertz, upper_edge_hertz); +} + +Status MelWeightMatrix::Compute(OpKernelContext* ctx) const { + const auto num_mel_bins = signal::get_scalar_value_from_tensor(ctx->Input(0)); + const auto dft_length = signal::get_scalar_value_from_tensor(ctx->Input(1)); + const auto sample_rate = signal::get_scalar_value_from_tensor(ctx->Input(2)); + const auto lower_edge_hertz = signal::get_scalar_value_from_tensor(ctx->Input(3)); + const auto upper_edge_hertz = signal::get_scalar_value_from_tensor(ctx->Input(4)); + + return create_mel_weight_matrix(ctx, data_type_, num_mel_bins, dft_length, sample_rate, lower_edge_hertz, + upper_edge_hertz); +} +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/signal/window_functions.h b/onnxruntime/core/providers/cpu/signal/window_functions.h similarity index 77% rename from onnxruntime/contrib_ops/cpu/signal/window_functions.h rename to onnxruntime/core/providers/cpu/signal/window_functions.h index 81d8d3b48c656..052c3ac43a16a 100644 --- a/onnxruntime/contrib_ops/cpu/signal/window_functions.h +++ b/onnxruntime/core/providers/cpu/signal/window_functions.h @@ -1,18 +1,19 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
-#ifdef BUILD_MS_EXPERIMENTAL_OPS +#include "core/common/common.h" +#include "core/framework/op_kernel.h" namespace onnxruntime { -namespace contrib { class VariableOutputDataTypeBase : public OpKernel { protected: onnx::TensorProto_DataType data_type_; public: - VariableOutputDataTypeBase(const OpKernelInfo& info) : OpKernel(info) { - data_type_ = static_cast(info.GetAttrOrDefault("output_datatype", onnx::TensorProto_DataType::TensorProto_DataType_FLOAT)); + explicit VariableOutputDataTypeBase(const OpKernelInfo& info) : OpKernel(info) { + data_type_ = static_cast( // + info.GetAttrOrDefault("output_datatype", onnx::TensorProto_DataType::TensorProto_DataType_FLOAT)); } }; @@ -44,7 +45,4 @@ class MelWeightMatrix final : public VariableOutputDataTypeBase { Status Compute(OpKernelContext* ctx) const override; }; -} // namespace contrib } // namespace onnxruntime - -#endif \ No newline at end of file diff --git a/onnxruntime/test/testdata/kernel_def_hashes/onnx.cpu.json b/onnxruntime/test/testdata/kernel_def_hashes/onnx.cpu.json index 399e26fb35fd6..d9271921c1e90 100644 --- a/onnxruntime/test/testdata/kernel_def_hashes/onnx.cpu.json +++ b/onnxruntime/test/testdata/kernel_def_hashes/onnx.cpu.json @@ -299,6 +299,10 @@ "BitShift ai.onnx CPUExecutionProvider", 8765933529403563240 ], + [ + "BlackmanWindow ai.onnx CPUExecutionProvider", + 4230790036355038984 + ], [ "Cast ai.onnx CPUExecutionProvider", 4892631558605514456 @@ -463,6 +467,10 @@ "Det ai.onnx CPUExecutionProvider", 4355346295804324544 ], + [ + "DFT ai.onnx CPUExecutionProvider", + 2809655513372322840 + ], [ "Div ai.onnx CPUExecutionProvider", 3765227735719542728 @@ -911,7 +919,7 @@ "GreaterOrEqual ai.onnx CPUExecutionProvider", 17416867432093505280 ], -[ + [ "GreaterOrEqual ai.onnx CPUExecutionProvider", 4445196831337347808 ], @@ -926,7 +934,7 @@ [ "GreaterOrEqual ai.onnx CPUExecutionProvider", 16172564801671050120 - ], + ], [ "GridSample ai.onnx CPUExecutionProvider", 15150264021585158264 @@ -939,6 +947,14 @@ 
"GRU ai.onnx CPUExecutionProvider", 2706165712066264784 ], + [ + "HammingWindow ai.onnx CPUExecutionProvider", + 7960927909626268504 + ], + [ + "HannWindow ai.onnx CPUExecutionProvider", + 11998243503561799520 + ], [ "Hardmax ai.onnx CPUExecutionProvider", 3471079605532327368 @@ -1018,7 +1034,7 @@ [ "LeakyRelu ai.onnx CPUExecutionProvider", 830582302303937272 - ], + ], [ "Less ai.onnx CPUExecutionProvider", 2529281912870061232 @@ -1090,7 +1106,7 @@ [ "LessOrEqual ai.onnx CPUExecutionProvider", 15565321713560893128 - ], + ], [ "Log ai.onnx CPUExecutionProvider", 268464912229648680 @@ -1287,6 +1303,10 @@ "MeanVarianceNormalization ai.onnx CPUExecutionProvider", 17242016597551698064 ], + [ + "MelWeightMatrix ai.onnx CPUExecutionProvider", + 1589563865873170600 + ], [ "Min ai.onnx CPUExecutionProvider", 5444634510407971152 @@ -1586,7 +1606,7 @@ [ "PRelu ai.onnx CPUExecutionProvider", 17872917958807301128 - ], + ], [ "QLinearConv ai.onnx CPUExecutionProvider", 1301685544574905024 @@ -2230,7 +2250,7 @@ [ "Scan ai.onnx CPUExecutionProvider", 220271302879298784 - ], + ], [ "Scatter ai.onnx CPUExecutionProvider", 15759064509848656392 @@ -2447,6 +2467,10 @@ "Squeeze ai.onnx CPUExecutionProvider", 16122603335179721968 ], + [ + "STFT ai.onnx CPUExecutionProvider", + 1739051453790648552 + ], [ "StringNormalizer ai.onnx CPUExecutionProvider", 7767393334034626736 @@ -2698,9 +2722,9 @@ [ "Where ai.onnx CPUExecutionProvider", 17544214758602217832 - ], + ], [ "Xor ai.onnx CPUExecutionProvider", 14631049987911195736 ] -] \ No newline at end of file +] From 63a179bd0408e8ee8aceab9fded43293a78cafe9 Mon Sep 17 00:00:00 2001 From: Gary Miguel Date: Fri, 10 Jun 2022 21:49:29 +0000 Subject: [PATCH 02/20] undo formatting changes to cpu_execution_provider.cc --- .../providers/cpu/cpu_execution_provider.cc | 910 +++++++++--------- 1 file changed, 431 insertions(+), 479 deletions(-) diff --git a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc 
b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc index 0e76fabdfffc9..db76f24209498 100644 --- a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc +++ b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc @@ -137,10 +137,8 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOn class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, double, Softmax); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 9, float, TopK); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 9, double, TopK); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 8, float, - BatchNormalization); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 8, double, - BatchNormalization); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 8, float, BatchNormalization); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 8, double, BatchNormalization); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, Conv); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, ConvTranspose); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 8, Flatten); @@ -161,7 +159,8 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOn class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, float, ReduceL2); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, int32_t, ReduceL2); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, float, ReduceLogSum); -class 
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, int32_t, ReduceLogSum); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, int32_t, + ReduceLogSum); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, float, ReduceLogSumExp); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, double, @@ -258,28 +257,17 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOn class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 12, Sign); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, Shrink); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 12, float, Erf); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10, - int64_t_int64_t_int64_t, OneHot); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10, float_int64_t_int64_t, - OneHot); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10, int64_t_string_int64_t, - OneHot); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10, float_string_int64_t, - OneHot); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10, float_float_float, - OneHot); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10, int64_t_int32_t_float, - OneHot); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10, int64_t_float_int64_t, - OneHot); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10, int32_t_float_int32_t, - OneHot); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, 
kOnnxDomain, 9, 10, int32_t_float_float, - OneHot); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10, int64_t_float_float, - OneHot); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10, int64_t_float_int32_t, - OneHot); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10, int64_t_int64_t_int64_t, OneHot); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10, float_int64_t_int64_t, OneHot); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10, int64_t_string_int64_t, OneHot); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10, float_string_int64_t, OneHot); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10, float_float_float, OneHot); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10, int64_t_int32_t_float, OneHot); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10, int64_t_float_int64_t, OneHot); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10, int32_t_float_int32_t, OneHot); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10, int32_t_float_float, OneHot); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10, int64_t_float_float, OneHot); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10, int64_t_float_int32_t, OneHot); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10, MaxUnpool); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, Sinh); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 
9, Cosh); @@ -307,10 +295,8 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOn class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 12, double, MatMul); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 12, int32_t, MatMul); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 12, int64_t, MatMul); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 13, float, - BatchNormalization); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 13, double, - BatchNormalization); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 13, float, BatchNormalization); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 13, double, BatchNormalization); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 15, PRelu); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 9, float, Upsample); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 9, int32_t, Upsample); @@ -328,16 +314,11 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOn class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 10, int8_t, Resize); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 10, uint8_t, Resize); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, ThresholdedRelu); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 12, uint8_t, - DequantizeLinear); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 12, int8_t, - DequantizeLinear); -class 
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 12, int32_t, - DequantizeLinear); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 12, uint8_t, - QuantizeLinear); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 12, int8_t, - QuantizeLinear); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 12, uint8_t, DequantizeLinear); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 12, int8_t, DequantizeLinear); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 12, int32_t, DequantizeLinear); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 12, uint8_t, QuantizeLinear); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 12, int8_t, QuantizeLinear); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, uint8_t, QLinearMatMul); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, int8_t, QLinearMatMul); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, uint8_t, MatMulInteger); @@ -381,14 +362,10 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOn class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, float, ReduceL2); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, int32_t, ReduceL2); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, float, ReduceLogSum); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, int32_t, - ReduceLogSum); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, 
kOnnxDomain, 11, 12, float, - ReduceLogSumExp); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, double, - ReduceLogSumExp); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, int32_t, - ReduceLogSumExp); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, int32_t, ReduceLogSum); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, float, ReduceLogSumExp); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, double, ReduceLogSumExp); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, int32_t, ReduceLogSumExp); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 11, float, ReduceMax); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 11, double, ReduceMax); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 11, int32_t, ReduceMax); @@ -407,12 +384,9 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOn class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, double, ReduceSum); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, int32_t, ReduceSum); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, int64_t, ReduceSum); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, float, - ReduceSumSquare); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, double, - ReduceSumSquare); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, int32_t, - ReduceSumSquare); 
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, float, ReduceSumSquare); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, double, ReduceSumSquare); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, int32_t, ReduceSumSquare); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, Hardmax); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, float, LogSoftmax); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, double, LogSoftmax); @@ -506,22 +480,15 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 12, Ei // class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 12, MLFloat16_float, Dropout); // class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 12, MLFloat16_double, Dropout); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 12, 12, float_float, Dropout); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 12, 12, float_double, - Dropout); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 12, 12, double_float, - Dropout); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 12, 12, double_double, - Dropout); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 12, 12, float_double, Dropout); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 12, 12, double_float, Dropout); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 12, 12, double_double, Dropout); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 12, Celu); 
-class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 12, 15, float, - GreaterOrEqual); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 12, 15, double, - GreaterOrEqual); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 12, 15, int32_t, - GreaterOrEqual); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 12, 15, int64_t, - GreaterOrEqual); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 12, 15, float, GreaterOrEqual); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 12, 15, double, GreaterOrEqual); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 12, 15, int32_t, GreaterOrEqual); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 12, 15, int64_t, GreaterOrEqual); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 12, 15, float, LessOrEqual); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 12, 15, double, LessOrEqual); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 12, 15, int32_t, LessOrEqual); @@ -739,10 +706,8 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 14, int64_t, Div); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 14, Reshape); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 14, 15, Identity); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 14, 14, float, - BatchNormalization); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 14, 14, double, - 
BatchNormalization); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 14, 14, float, BatchNormalization); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 14, 14, double, BatchNormalization); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 14, GRU); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 14, LSTM); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 14, RNN); @@ -838,123 +803,95 @@ KernelCreateInfo BuildKernelCreateInfo() { Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { static const BuildKernelCreateInfoFn function_table[] = { BuildKernelCreateInfo, // default entry to avoid the list become empty after ops-reducing - BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - 
BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -985,8 +922,7 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { double, Equal)>, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, 
BuildKernelCreateInfo, @@ -994,10 +930,8 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -1117,28 +1055,31 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + 
BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, // Opset 9 BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -1253,36 +1196,36 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { int64_t, NonZero)>, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + 
BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -1311,30 +1254,27 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, // opset 11 - BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -1366,8 +1307,10 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { float, Equal)>, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + 
BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -1430,17 +1366,15 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, // REVIEW(codemzs): ConstEigenVectorArrayMap.cast, - // BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, // opset 13 
BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, 
- BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -1912,41 +1859,53 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { Softmax)>, // OpSet 14 - BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, 
BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -1980,12 +1939,9 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -2005,13 +1961,14 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { LessOrEqual)>, BuildKernelCreateInfo, + // Opset 17 BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, }; for (auto& function_table_entry : function_table) { @@ -2058,18 +2015,12 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 1, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 1, int32_t, Scaler); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 1, SVMClassifier); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 1, SVMRegressor); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 1, 2, float, - TreeEnsembleClassifier); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 1, 2, double, - TreeEnsembleClassifier); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 1, 2, int64_t, - TreeEnsembleClassifier); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 
1, 2, int32_t, - TreeEnsembleClassifier); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 1, 2, float, - TreeEnsembleRegressor); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 1, 2, double, - TreeEnsembleRegressor); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 1, 2, float, TreeEnsembleClassifier); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 1, 2, double, TreeEnsembleClassifier); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 1, 2, int64_t, TreeEnsembleClassifier); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 1, 2, int32_t, TreeEnsembleClassifier); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 1, 2, float, TreeEnsembleRegressor); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 1, 2, double, TreeEnsembleRegressor); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 1, ZipMap); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMLDomain, 2, float_string, LabelEncoder); @@ -2137,25 +2088,26 @@ Status RegisterOnnxMLOperatorKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo CPUExecutionProvider::GetKernelRegistry() const std::unique_ptr CPUExecutionProvider::GetDataTransfer() const { 
return std::make_unique(); } -} // namespace onnxruntime +} // namespace onnxruntime \ No newline at end of file From 718faeade696be9f23d0033e616ef91366811ae1 Mon Sep 17 00:00:00 2001 From: Gary Miguel Date: Fri, 10 Jun 2022 21:50:51 +0000 Subject: [PATCH 03/20] add trailing new line --- onnxruntime/core/providers/cpu/cpu_execution_provider.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc index db76f24209498..2c5d9abee64ac 100644 --- a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc +++ b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc @@ -2180,4 +2180,4 @@ std::shared_ptr CPUExecutionProvider::GetKernelRegistry() const std::unique_ptr CPUExecutionProvider::GetDataTransfer() const { return std::make_unique(); } -} // namespace onnxruntime \ No newline at end of file +} // namespace onnxruntime From 183477e71876e75ded10beb85ef4154990235cdc Mon Sep 17 00:00:00 2001 From: Gary Miguel Date: Tue, 14 Jun 2022 17:01:08 +0000 Subject: [PATCH 04/20] code review comments --- .../providers/cpu/cpu_execution_provider.cc | 2 +- onnxruntime/core/providers/cpu/signal/dft.cc | 26 +++++++------------ onnxruntime/core/providers/cpu/signal/dft.h | 2 +- 3 files changed, 11 insertions(+), 19 deletions(-) diff --git a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc index 2c5d9abee64ac..191f34439c7bf 100644 --- a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc +++ b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc @@ -1968,7 +1968,7 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, }; for (auto& function_table_entry : function_table) { diff --git a/onnxruntime/core/providers/cpu/signal/dft.cc 
b/onnxruntime/core/providers/cpu/signal/dft.cc index 8223b3bfa9e44..c4a5bf680eaaf 100644 --- a/onnxruntime/core/providers/cpu/signal/dft.cc +++ b/onnxruntime/core/providers/cpu/signal/dft.cc @@ -192,8 +192,8 @@ static Status fft_radix2(OpKernelContext* /*ctx*/, const Tensor* X, Tensor* Y, s V = InlinedVector>(dft_length); // e^(i *2*pi / N * k) for (size_t i = 0; i < dft_length; i++) { size_t bit_reversed_index = bit_reverse(i, significant_bits); - V[bit_reversed_index] = std::complex(cos(static_cast(i) * angular_velocity), - sin(static_cast(i) * angular_velocity)); + const T angle = static_cast(i) * angular_velocity; + V[bit_reversed_index] = std::complex(cos(angle), sin(angle)); } } @@ -270,9 +270,8 @@ static Status dft_naive(const Tensor* X, Tensor* Y, size_t X_offset, size_t X_st out.imag(0); for (size_t j = 0; j < dft_length; j++) { // vectorize over this loop - auto exponential = std::complex( - cos(static_cast(i) * static_cast(j) * angular_velocity), - sin(static_cast(i) * static_cast(j) * angular_velocity)); + const T angle = static_cast(i) * static_cast(j) * angular_velocity; + auto exponential = std::complex(cos(angle), sin(angle)); auto window_element = window_data ? *(window_data + j) : 1; auto x = (j < number_of_samples) ? *(X_data + j * X_stride) : 0; auto element = x * window_element; @@ -355,15 +354,7 @@ static Status discrete_fourier_transform(OpKernelContext* ctx, int64_t axis, boo const auto& X_shape = X->Shape(); const auto is_real_valued = is_real_valued_signal(X_shape); const auto is_complex_valued = is_complex_valued_signal(X_shape); - - // Get the rank of the input tensor - // Ensure that the axis is in the valid range of [-rank, rank) - auto rank = static_cast(X_shape.GetDims().size()); - if (!(-rank <= axis && axis < rank)) { - ORT_RETURN_IF(!(-rank <= axis && axis < rank), "axis attribute value ", axis, " is invalid for a tensor of rank ", - rank); - } - axis = (axis >= 0 ? 
axis : axis + rank); + axis = HandleNegativeAxis(axis, X_shape.NumDimensions()); int64_t number_of_samples = static_cast(X_shape[axis]); if (dft_length) { @@ -402,7 +393,7 @@ static Status discrete_fourier_transform(OpKernelContext* ctx, int64_t axis, boo ctx, X, Y, axis, number_of_samples, nullptr, is_onesided, inverse, V, temp_output))); } else { ORT_THROW( - "Unsupported input signal shape. The signal's first dimenstion must be the batch dimension and its second " + "Unsupported input signal shape. The signal's first dimension must be the batch dimension and its second " "dimension must be the signal length dimension. It may optionally include a 3rd dimension of size 2 for " "complex inputs.", data_type); @@ -418,7 +409,7 @@ static Status discrete_fourier_transform(OpKernelContext* ctx, int64_t axis, boo ctx, X, Y, axis, number_of_samples, nullptr, is_onesided, inverse, V, temp_output))); } else { ORT_THROW( - "Unsupported input signal shape. The signal's first dimenstion must be the batch dimension and its second " + "Unsupported input signal shape. The signal's first dimension must be the batch dimension and its second " "dimension must be the signal length dimension. It may optionally include a 3rd dimension of size 2 for " "complex inputs.", data_type); @@ -457,7 +448,8 @@ static Status short_time_fourier_transform(OpKernelContext* ctx, bool is_oneside const auto signal_components = signal_shape.NumDimensions() == 2 ? 1 : signal_shape.NumDimensions() == 3 ? 
signal_shape[2] : 0; // error - ORT_ENFORCE(signal_components == 1 || signal_components == 2, "Ensure that the signal has either 1 or 2 components."); + ORT_ENFORCE(signal_components == 1 || signal_components == 2, + "signal shape must end in 1 (real) or 2 (real, imaginary)."); // Get the frame length int64_t frame_length = std::numeric_limits::min(); diff --git a/onnxruntime/core/providers/cpu/signal/dft.h b/onnxruntime/core/providers/cpu/signal/dft.h index 17f25a8f975a3..71cac52e37e8f 100644 --- a/onnxruntime/core/providers/cpu/signal/dft.h +++ b/onnxruntime/core/providers/cpu/signal/dft.h @@ -14,7 +14,7 @@ class DFT final : public OpKernel { public: explicit DFT(const OpKernelInfo& info) : OpKernel(info) { is_onesided_ = static_cast(info.GetAttrOrDefault("onesided", 0)); - axis_ = info.GetAttrOrDefault("axis", 0); + axis_ = info.GetAttrOrDefault("axis", 1); is_inverse_ = info.GetAttrOrDefault("inverse", 0); } Status Compute(OpKernelContext* ctx) const override; From 6f925e132c7ef3a200e4decaebddb6a91fab990e Mon Sep 17 00:00:00 2001 From: Gary Miguel Date: Wed, 15 Jun 2022 23:55:27 +0000 Subject: [PATCH 05/20] Delete experimental op schemas. 
--- .../core/graph/contrib_ops/contrib_defs.cc | 7 - .../core/graph/signal_ops/signal_defs.cc | 636 ------------------ .../core/graph/signal_ops/signal_defs.h | 36 - 3 files changed, 679 deletions(-) delete mode 100644 onnxruntime/core/graph/signal_ops/signal_defs.cc delete mode 100644 onnxruntime/core/graph/signal_ops/signal_defs.h diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc index 8cf2d278e0ead..fb7bc16cc190a 100644 --- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc @@ -16,7 +16,6 @@ #include "core/graph/contrib_ops/range_schema_defs.h" #include "core/graph/op.h" #include "core/mlas/inc/mlas.h" -#include "core/graph/signal_ops/signal_defs.h" #include "core/graph/contrib_ops/onnx_function_util.h" #include "onnx/defs/function.h" @@ -370,7 +369,6 @@ void sparseCompatibleMatmulShapeInference( updateOutputShape(ctx, 0, resultShape, default_tensor_type); } - bool ParseScalar(const TensorProto* initializer, int& value) { std::vector parsed_data; if (initializer->data_type() == TensorProto::INT32) { @@ -2417,7 +2415,6 @@ void RegisterContribSchemas() { // } // updateOutputShape(ctx, 0, disentangled_attention_shape); propagateShapeFromInputToOutput(ctx, 0, 0); - }); ONNX_CONTRIB_OPERATOR_SCHEMA(Snpe) @@ -2535,10 +2532,6 @@ This op functions in much the same was as Dropout-11 and Dropout-13 do, execpt t RegisterNchwcSchemas(); } #endif - -#ifdef BUILD_MS_EXPERIMENTAL_OPS - onnxruntime::signal::RegisterSignalSchemas(); -#endif } } // namespace contrib diff --git a/onnxruntime/core/graph/signal_ops/signal_defs.cc b/onnxruntime/core/graph/signal_ops/signal_defs.cc deleted file mode 100644 index ffca16754319e..0000000000000 --- a/onnxruntime/core/graph/signal_ops/signal_defs.cc +++ /dev/null @@ -1,636 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
- -#ifdef BUILD_MS_EXPERIMENTAL_OPS - -#include "core/graph/signal_ops/signal_defs.h" - -#include -#include - -#include "core/framework/tensorprotoutils.h" -#include "core/providers/common.h" -#include "core/graph/constants.h" -#include "core/graph/op.h" -#include "onnx/defs/schema.h" -#include "onnx/defs/shape_inference.h" -#include "onnx/defs/tensor_proto_util.h" - -// NOTE: These were added to the standard op set. We register them under the MS domain -// for backwards compatibility, but new users should use the standard ops instead. Ideally these would be deleted. -namespace onnxruntime { -namespace signal { - -using ONNX_NAMESPACE::AttributeProto; -using ONNX_NAMESPACE::OpSchema; -using ONNX_NAMESPACE::OPTIONAL_VALUE; - -template -static T get_scalar_value_from_tensor(const ONNX_NAMESPACE::TensorProto* t) { - if (t == nullptr) { - return T{}; - } - - auto data_type = t->data_type(); - switch (data_type) { - case ONNX_NAMESPACE::TensorProto::FLOAT: - return static_cast(ONNX_NAMESPACE::ParseData(t).at(0)); - case ONNX_NAMESPACE::TensorProto::DOUBLE: - return static_cast(ONNX_NAMESPACE::ParseData(t).at(0)); - case ONNX_NAMESPACE::TensorProto::INT32: - return static_cast(ONNX_NAMESPACE::ParseData(t).at(0)); - case ONNX_NAMESPACE::TensorProto::INT64: - return static_cast(ONNX_NAMESPACE::ParseData(t).at(0)); - default: - ORT_THROW("Unsupported input data type of ", data_type); - } -} - -inline const ONNX_NAMESPACE::TensorShapeProto* getOptionalInputShape(ONNX_NAMESPACE::InferenceContext& ctx, size_t n) { - const auto* input_type = ctx.getInputType(n); - - if (input_type == nullptr) { - return nullptr; - } - - const auto value_case = input_type->value_case(); - if (value_case != ONNX_NAMESPACE::TypeProto::kTensorType && - value_case != ONNX_NAMESPACE::TypeProto::kSparseTensorType) { - fail_type_inference("Attribute expected to have tensor or sparse tensor type"); - } - if (value_case == ONNX_NAMESPACE::TypeProto::kTensorType) { - return 
&input_type->tensor_type().shape(); - } else { - return &input_type->sparse_tensor_type().shape(); - } -} - -std::function CosineSumWindowOpDocGenerator(const char* name) { - return [name](OpSchema& schema) { - std::string doc; - POPULATE_OP_DOC_STR(doc = R"DOC( -Generates a {name} window as described in the paper https://ieeexplore.ieee.org/document/1455106. -)DOC"; - ReplaceAll(doc, "{name}", name);); - - schema.SetDoc(doc); - schema.Attr("output_datatype", - "The data type of the output tensor. " - "Strictly must be one of the values from DataType enum in TensorProto whose values correspond to T2. " - "The default value is 1 = FLOAT. ", - AttributeProto::INT, static_cast(onnx::TensorProto_DataType::TensorProto_DataType_FLOAT)); - schema.Attr("periodic", - "If 1, returns a window to be used as periodic function. If 0, return a symmetric window. " - "When 'periodic' is specified, hann computes a window of length size + 1 and returns the first size " - "points. The default value is 1. ", - AttributeProto::INT, static_cast(1)); - schema.Input(0, "size", "A scalar value indicating the length of the window.", "T1", OpSchema::Single, true, 1, - OpSchema::NonDifferentiable); - schema.Output(0, "output", - "A Hann window with length: size. " - "The output has the shape: [size].", - "T2", OpSchema::Single, true, 1, OpSchema::NonDifferentiable); - schema.TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { - // Update the output data type to the output_datatype - auto output_datatype = getAttribute(ctx, "output_datatype", - static_cast(onnx::TensorProto_DataType::TensorProto_DataType_FLOAT)); - updateOutputElemType(ctx, 0, static_cast(output_datatype)); - - if (!hasInputShape(ctx, 0)) { - // If no shape is available for the input, skip shape inference. 
- return; - } - - const auto* size = ctx.getInputData(0); - if (size == nullptr) { - // Size is not available, so return early - return; - } - - if (size->dims_size() != 0) { - fail_shape_inference("size input must be a scalar."); - } - - auto size_value = get_scalar_value_from_tensor(size); - if (size_value <= 0) { - fail_shape_inference("size input must be greater than 0."); - } - - ONNX_NAMESPACE::TensorShapeProto result_shape; - result_shape.add_dim()->set_dim_value(size_value); - updateOutputShape(ctx, 0, result_shape); - }); - }; -} - -ONNX_NAMESPACE::NodeProto WindowOpFunctionNode(const char* name) { - ONNX_NAMESPACE::NodeProto node; - node.set_op_type(std::string(name) + "Window"); - - auto* output_datatype = node.add_attribute(); - output_datatype->set_name("output_datatype"); - output_datatype->set_ref_attr_name("output_datatype"); - - auto* periodic = node.add_attribute(); - periodic->set_name("periodic"); - periodic->set_ref_attr_name("periodic"); - - node.add_input("size"); - node.add_output("output"); - return node; -} - -void RegisterSignalSchemas() { - ONNX_NAMESPACE::NodeProto dft_function_node; - dft_function_node.set_op_type("DFT"); - - auto* dft_function_onesided = dft_function_node.add_attribute(); - dft_function_onesided->set_name("onesided"); - dft_function_onesided->set_ref_attr_name("onesided"); - - auto* dft_function_axis = dft_function_node.add_attribute(); - dft_function_axis->set_name("axis"); - dft_function_axis->set_ref_attr_name("axis"); - - auto* dft_function_inverse = dft_function_node.add_attribute(); - dft_function_inverse->set_name("inverse"); - dft_function_inverse->set_ref_attr_name("inverse"); - - dft_function_node.add_input("input"); - dft_function_node.add_input("dft_length"); - dft_function_node.add_output("output"); - - ONNX_NAMESPACE::OperatorSetIdProto onnx_op_set_17; - onnx_op_set_17.set_domain(kOnnxDomain); - onnx_op_set_17.set_version(17); - - MS_SIGNAL_OPERATOR_SCHEMA(DFT) - .SetDomain(kMSExperimentalDomain) - 
.SinceVersion(1) - .SetDoc(R"DOC(DFT)DOC") - .Attr("onesided", - "If True (default), only values for half of the fft size are returned because the real-to-complex Fourier " - "transform satisfies the conjugate symmetry.The output tensor will return the first floor(n_fft/2) + 1 " - "values from the DFT. Values can be 0 or 1.", - AttributeProto::AttributeType::AttributeProto_AttributeType_INT, static_cast(0)) - .Attr("axis", - "The axis on which to perform the DFT. By default this value is set to 0, which corresponds to the first " - "dimension after the batch index. This value must be less than signal_dimN, where signal_dimN is the " - "number of dimensions in the signal.", - AttributeProto::AttributeType::AttributeProto_AttributeType_INT, static_cast(0)) - .Attr("inverse", - "Whether to perform the inverse discrete fourier transform. By default this value is set to 0, which " - "corresponds to false.", - AttributeProto::INT, static_cast(0)) - .Input(0, "input", - "For real input, the following shape is expected: [batch_idx][signal_dim1][signal_dim2]..." - "[signal_dimN][1]. For complex input, the following shape is expected: " - "[batch_idx][signal_dim1][signal_dim2]...[signal_dimN][2]. The first dimension is the batch dimension. " - "The following N dimensions correspond to the signal's dimensions. " - "The final dimension represents the real and imaginary parts of the value in that order.", - "T1", OpSchema::Single, true, 1, OpSchema::NonDifferentiable) - .Input(1, "dft_length", - "The length of the signal." - "If greater than the axis dimension, the signal will be zero-padded up to dft_length. " - "If less than the axis dimension, only the first dft_length values will be used as the signal. " - "It's an optional value. ", - "T2", OpSchema::Optional, true, 1, OpSchema::NonDifferentiable) - .Output(0, "output", - "The Fourier Transform of the input vector." 
- "If onesided is 0, the following shape is expected: " - "[batch_idx][signal_dim1][signal_dim2]...[signal_dimN][2]. " - "If axis=0 and onesided is 1, the following shape is expected: " - "[batch_idx][floor(signal_dim1/2)+1][signal_dim2]...[signal_dimN][2]. " - "If axis=1 and onesided is 1, the following shape is expected: " - "[batch_idx][signal_dim1][floor(signal_dim2/2)+1]...[signal_dimN][2]. " - "If axis=N-1 and onesided is 1, the following shape is expected: " - "[batch_idx][signal_dim1][signal_dim2]...[floor(signal_dimN/2)+1][2]. " - "The signal_dim at the specified axis is equal to the dft_length.", - "T1") - .TypeConstraint("T1", {"tensor(float16)", "tensor(float)", "tensor(double)", "tensor(bfloat16)"}, - "Constrain input and output types to float tensors.") - .TypeConstraint("T2", {"tensor(int32)", "tensor(int64)"}, "Constrain scalar length types to int64_t.") - .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { - bool is_onesided = static_cast(getAttribute(ctx, "onesided", 0)); - bool inverse = static_cast(getAttribute(ctx, "inverse", 0)); - - if (inverse && is_onesided) { - fail_shape_inference("is_onesided and inverse attributes cannot be enabled at the same time"); - } - - propagateElemTypeFromInputToOutput(ctx, 0, 0); - if (!hasInputShape(ctx, 0)) { - // If no shape is available for the input, skip shape inference... - return; - } - - // In general the output shape will match the input shape exactly - // So initialize the output shape with the input shape - auto& input_shape = getInputShape(ctx, 0); - ONNX_NAMESPACE::TensorShapeProto result_shape_proto = input_shape; - - // Get the axis where the DFT will be performed. - auto axis = static_cast(getAttribute(ctx, "axis", 1)); - auto rank = input_shape.dim_size(); - - if (!(-rank <= axis && axis < rank)) { - fail_shape_inference("axis attribute value ", axis, " is invalid for a tensor of rank ", rank); - } - - auto axis_idx = (axis >= 0 ? 
axis : axis + rank); - - // If dft_length is specified, then we should honor the shape. - // Set the output dimension to match the dft_length on the axis. - // If onesided this will be adjusted later on... - const ONNX_NAMESPACE::TensorProto* dft_length = nullptr; - if (ctx.getNumInputs() >= 2 && ctx.getInputType(1) != nullptr) { - dft_length = ctx.getInputData(1); - if (dft_length == nullptr) { - // If we cannot read the dft_length, we cannot infer shape - // return... - return; - } - } - - if (nullptr != dft_length) { - if (dft_length->dims_size() != 0) { - fail_shape_inference("dft_length input must be a scalar."); - } - auto dft_length_value = get_scalar_value_from_tensor(dft_length); - result_shape_proto.mutable_dim(axis_idx)->set_dim_value(dft_length_value); - } - // When DFT is onesided, the output shape is half the size of the input shape - // along the specified axis. - if (is_onesided) { - auto axis_dimension = result_shape_proto.dim(axis_idx); - // We need to update the output shape dimension along the specified axis, - // but sometimes the dimension will be a free dimension or be otherwise unset. - // Only perform inference when a input dimension value exists. - if (axis_dimension.has_dim_value()) { - auto original_signal_size = axis_dimension.dim_value(); - auto half_signal_size = (original_signal_size >> 1) + 1; - result_shape_proto.mutable_dim(axis_idx)->set_dim_value(half_signal_size); - } else { - // Clear the value and param (which would otherwie be inherited from the input). - result_shape_proto.mutable_dim(axis_idx)->clear_dim_value(); - result_shape_proto.mutable_dim(axis_idx)->clear_dim_param(); - } - } - - // Coerce the last dimension to 2. 
- auto dim_size = static_cast(result_shape_proto.dim_size()); - auto has_component_dimension = dim_size > 2; - - // This if check is retained in the contrib op and not the official spec for back compat - if (has_component_dimension) { - result_shape_proto.mutable_dim(static_cast(dim_size - 1))->set_dim_value(2); - } else { - result_shape_proto.add_dim()->set_dim_value(2); - } - - updateOutputShape(ctx, 0, result_shape_proto); - }) - .FunctionBody({dft_function_node}, {onnx_op_set_17}); - - ONNX_NAMESPACE::NodeProto idft_function_node; - idft_function_node.set_op_type("DFT"); - - auto* idft_function_inverse = idft_function_node.add_attribute(); - idft_function_inverse->set_name("inverse"); - idft_function_inverse->set_i(1); - - auto* idft_function_axis = idft_function_node.add_attribute(); - idft_function_axis->set_name("axis"); - idft_function_axis->set_ref_attr_name("axis"); - - idft_function_node.add_input("input"); - idft_function_node.add_input("dft_length"); - idft_function_node.add_output("output"); - - MS_SIGNAL_OPERATOR_SCHEMA(IDFT) - .SetDomain(kMSExperimentalDomain) - .SinceVersion(1) - .SetDoc(R"DOC(IDFT)DOC") - .Attr("axis", - "The axis on which to perform the DFT. By default this value is set to 0, which corresponds to the first " - "dimension after the batch index." - "This value must be less than signal_dimN, where signal_dimN is the number of dimensions in the signal.", - AttributeProto::AttributeType::AttributeProto_AttributeType_INT, static_cast(0)) - .Input(0, "input", - "For real multi-dimensional input, the following shape is expected: " - "[batch_idx][signal_dim1][signal_dim2]...[signal_dimN][1]." - "For complex multi-dimensional input, the following shape is expected: " - "[batch_idx][signal_dim1][signal_dim2]...[signal_dimN][2]." - "The first dimension is the batch dimension." - "The final dimension represents the real and imaginary parts of the value.", - "T1") - .Input(1, "dft_length", - "The length of the signal." 
- "If greater than the axis dimension, the signal will be zero-padded up to dft_length. " - "If less than the axis dimension, only the first dft_length values will be used as the signal. " - "It's an optional value. ", - "T2", OpSchema::Optional, true, 1, OpSchema::NonDifferentiable) - .Output(0, "output", - "The inverse discrete Fourier transform of the input. " - "The signal_dim at the specified axis is equal to the dft_length." - "The expected shape is [batch_idx][signal_dim1][signal_dim2]...[signal_dimN][2]" - "For all types of input, the last dimension of the output represents the components of a complex number.", - "T1", OpSchema::Single, true, 1, OpSchema::NonDifferentiable) - .TypeConstraint("T1", {"tensor(float16)", "tensor(float)", "tensor(double)", "tensor(bfloat16)"}, - "Constrain input and output types to float tensors.") - .TypeConstraint("T2", {"tensor(int64)"}, "Constrain scalar length types to int64_t.") - .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { - propagateElemTypeFromInputToOutput(ctx, 0, 0); - auto& input_shape = getInputShape(ctx, 0); - ONNX_NAMESPACE::TensorShapeProto result_shape = input_shape; - auto dim_size = static_cast(input_shape.dim_size()); - auto has_component_dimension = dim_size > 2; - - if (has_component_dimension) { - result_shape.mutable_dim(static_cast(dim_size - 1))->set_dim_value(2); - } else { - result_shape.add_dim()->set_dim_value(2); - } - - updateOutputShape(ctx, 0, result_shape); - }) - .FunctionBody({idft_function_node}, {onnx_op_set_17}); - - ONNX_NAMESPACE::NodeProto stft_function_node; - stft_function_node.set_op_type("STFT"); - - auto* stft_function_onesided = idft_function_node.add_attribute(); - stft_function_onesided->set_name("onesided"); - stft_function_onesided->set_ref_attr_name("onesided"); - - stft_function_node.add_input("signal"); - stft_function_node.add_input("frame_step"); - stft_function_node.add_input("window"); - stft_function_node.add_input("frame_length"); - 
stft_function_node.add_output("output"); - - MS_SIGNAL_OPERATOR_SCHEMA(STFT) - .SetDomain(kMSExperimentalDomain) - .SinceVersion(1) - .SetDoc(R"DOC(STFT)DOC") - .Attr("onesided", - "If onesided is 1, only values for w in [0, 1, 2, ..., floor(n_fft/2) + 1] are returned because " - "the real-to-complex Fourier transform satisfies the conjugate symmetry, i.e., X[m, w] = X[m,w] = " - "X[m,n_fft-w]*. Note if the input or window tensors are complex, then onesided output is not possible. " - "Enabling onesided with real inputs performs a Real-valued fast Fourier transform (RFFT)." - "When invoked with real or complex valued input, the default value is 1. " - "Values can be 0 or 1.", - AttributeProto::INT, static_cast(1)) - .Input(0, "signal", - "Input tensor representing a real or complex valued signal. " - "For real input, the following shape is expected: [batch_size][signal_length][1]. " - "For complex input, the following shape is expected: [batch_size][signal_length][2], where " - "[batch_size][signal_length][0] represents the real component and [batch_size][signal_length][1] " - "represents the imaginary component of the signal.", - "T1", OpSchema::Single, true, 1, OpSchema::NonDifferentiable) - .Input(1, "frame_step", "The number of samples to step between successive DFTs.", "T2", OpSchema::Single, true, 1, - OpSchema::NonDifferentiable) - .Input(2, "window", - "A tensor representing the window that will be slid over the signal." - "The window must have rank 1 with shape: [window_shape]. " - "It's an optional value. ", - "T1", OpSchema::Optional, true, 1, OpSchema::NonDifferentiable) - .Input(3, "frame_length", - "A scalar representing the size of the DFT. " - "It's an optional value.", - "T2", OpSchema::Optional, true, 1, OpSchema::NonDifferentiable) - .Output(0, "output", - "The Short-time Fourier Transform of the signals." 
- "If onesided is 1, the output has the shape: [batch_size][frames][dft_unique_bins][2], where " - "dft_unique_bins is frame_length // 2 + 1 (the unique components of the DFT) " - "If onesided is 0, the output has the shape: [batch_size][frames][frame_length][2], where frame_length " - "is the length of the DFT.", - "T1", OpSchema::Single, true, 1, OpSchema::NonDifferentiable) - .TypeConstraint("T1", {"tensor(float)", "tensor(float16)", "tensor(double)", "tensor(bfloat16)"}, - "Constrain signal and output to float tensors.") - .TypeConstraint("T2", {"tensor(int32)", "tensor(int64)"}, "Constrain scalar length types to int64_t.") - .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { - propagateElemTypeFromInputToOutput(ctx, 0, 0); - - // Get signal size - // The signal size is needed to perform inference because the size of the signal - // is needed to compute the number of DFTs in the output. - // - // 1) Check if shape exists, return if not - // 2) Get the shape - // 3) Check if signal dim value exists, return if not - if (!hasInputShape(ctx, 0)) { - return; - } - - auto& input_shape = getInputShape(ctx, 0); - auto signal_dim = input_shape.dim(1); - if (!signal_dim.has_dim_value()) { - return; - } - auto signal_size = signal_dim.dim_value(); - - // The frame step is a required input. - // Its value is needed to compute the number output nDFTs, so return early is missing. - const auto* frame_step = ctx.getInputData(1); - if (nullptr == frame_step) { - return; - } - auto frame_step_value = get_scalar_value_from_tensor(frame_step); - - // Determine the size of the DFT based on the 2 optional inputs window and frame_length. - // One must be set. - int64_t dft_size = -1; - const ONNX_NAMESPACE::TensorProto* frame_length = nullptr; - if (ctx.getNumInputs() >= 4 && ctx.getInputType(3) != nullptr) { - frame_length = ctx.getInputData(3); - if (frame_length == nullptr) { - // If we cannot read the frame_length, we cannot infer shape - // return... 
- return; - } - } - - const ONNX_NAMESPACE::TensorShapeProto* window_shape = nullptr; - if (ctx.getNumInputs() >= 3) { - window_shape = ONNX_NAMESPACE::getOptionalInputShape(ctx, 2); - } else { - window_shape = nullptr; - } - - if (window_shape == nullptr && frame_length == nullptr) { - // STFT expects to have at least one of these inputs set: [window, frame_length], - // but they may not be available at shape inference time - return; - } else if (window_shape != nullptr && frame_length != nullptr) { - if (frame_length->dims_size() != 0) { - fail_shape_inference("frame_length input must be scalar."); - } - auto frame_length_value = get_scalar_value_from_tensor(frame_length); - - // Ensure that the window length and the dft_length match. - if (window_shape->dim_size() != 1) { - fail_shape_inference("window input must have rank = 1."); - } - if (window_shape->dim(0).has_dim_value()) { - auto window_length = window_shape->dim(0).dim_value(); - if (window_length != frame_length_value) { - fail_type_inference( - "If STFT has both a window input and frame_length specified, the dimension of the " - "window must match the frame_length specified!"); - } - } - - dft_size = frame_length_value; - } else if (window_shape != nullptr) { - // Ensure that the window length and the dft_length match. - if (window_shape->dim_size() != 1) { - fail_shape_inference("window input must have rank = 1."); - } - if (window_shape->dim(0).has_dim_value()) { - dft_size = window_shape->dim(0).dim_value(); - } else { - // Cannot determine the window size, and there is no frame_length, - // So shape inference cannot proceed. - return; - } - } else if (frame_length != nullptr) { - if (frame_length->dims_size() != 0) { - fail_shape_inference("frame_length input must be scalar."); - } - dft_size = get_scalar_value_from_tensor(frame_length); - } - - bool is_onesided = static_cast(getAttribute(ctx, "onesided", 0)); - if (is_onesided) { - dft_size = is_onesided ? 
((dft_size >> 1) + 1) : dft_size; - } - - auto n_dfts = static_cast((signal_size - dft_size) / static_cast(frame_step_value)) + 1; - - // The output has the following shape: [batch_size][frames][dft_unique_bins][2] - ONNX_NAMESPACE::TensorShapeProto result_shape_proto; - result_shape_proto.add_dim()->set_dim_value(input_shape.dim(0).dim_value()); // batch size - result_shape_proto.add_dim()->set_dim_value(n_dfts); - result_shape_proto.add_dim()->set_dim_value(dft_size); - result_shape_proto.add_dim()->set_dim_value(2); - updateOutputShape(ctx, 0, result_shape_proto); - }) - .FunctionBody({stft_function_node}, {onnx_op_set_17}); - - // Window Functions - MS_SIGNAL_OPERATOR_SCHEMA(HannWindow) - .SetDomain(kMSExperimentalDomain) - .SinceVersion(1) - .FillUsing(CosineSumWindowOpDocGenerator("Hann")) - .TypeConstraint("T1", {"tensor(int32)", "tensor(int64)"}, "Constrain the input size to int64_t.") - .TypeConstraint("T2", ONNX_NAMESPACE::OpSchema::all_numeric_types_with_bfloat(), - "Constrain output types to numeric tensors.") - .FunctionBody({WindowOpFunctionNode("Hann")}, {onnx_op_set_17}); - - MS_SIGNAL_OPERATOR_SCHEMA(HammingWindow) - .SetDomain(kMSExperimentalDomain) - .SinceVersion(1) - .FillUsing(CosineSumWindowOpDocGenerator("Hamming")) - .TypeConstraint("T1", {"tensor(int32)", "tensor(int64)"}, "Constrain the input size to int64_t.") - .TypeConstraint("T2", ONNX_NAMESPACE::OpSchema::all_numeric_types_with_bfloat(), - "Constrain output types to numeric tensors.") - .FunctionBody({WindowOpFunctionNode("Hamming")}, {onnx_op_set_17}); - - MS_SIGNAL_OPERATOR_SCHEMA(BlackmanWindow) - .SetDomain(kMSExperimentalDomain) - .SinceVersion(1) - .FillUsing(CosineSumWindowOpDocGenerator("Blackman")) - .TypeConstraint("T1", {"tensor(int32)", "tensor(int64)"}, "Constrain the input size to int64_t.") - .TypeConstraint("T2", ONNX_NAMESPACE::OpSchema::all_numeric_types_with_bfloat(), - "Constrain output types to numeric tensors.") - 
.FunctionBody({WindowOpFunctionNode("Blackman")}, {onnx_op_set_17}); - - ONNX_NAMESPACE::NodeProto mel_function_node; - stft_function_node.set_op_type("MelWeightMatrix"); - - auto* mel_function_output_datatype = idft_function_node.add_attribute(); - mel_function_output_datatype->set_name("output_datatype"); - mel_function_output_datatype->set_ref_attr_name("output_datatype"); - - mel_function_node.add_input("num_mel_bins"); - mel_function_node.add_input("dft_length"); - mel_function_node.add_input("sample_rate"); - mel_function_node.add_input("lower_edge_hertz"); - mel_function_node.add_input("upper_edge_hertz"); - mel_function_node.add_output("output"); - - static const char* MelWeightMatrix_doc = R"DOC( -Generate a MelWeightMatrix that can be used to re-weight a Tensor containing a linearly sampled frequency spectra -(from DFT or STFT) into num_mel_bins frequency information based on the [lower_edge_hertz, upper_edge_hertz] range on -the mel scale. This function defines the mel scale in terms of a frequency in hertz according to the following formula: - - mel(f) = 2595 * log10(1 + f/700) - -In the returned matrix, all the triangles (filterbanks) have a peak value of 1.0. - -The returned MelWeightMatrix can be used to right-multiply a spectrogram S of shape [frames, num_spectrogram_bins] of -linear scale spectrum values (e.g. STFT magnitudes) to generate a "mel spectrogram" M of shape [frames, num_mel_bins]. -)DOC"; - - MS_SIGNAL_OPERATOR_SCHEMA(MelWeightMatrix) - .SetDomain(kMSExperimentalDomain) - .SinceVersion(1) - .SetDoc(MelWeightMatrix_doc) - .Attr("output_datatype", - "The data type of the output tensor. 
" - "Strictly must be one of the types from DataType enum in TensorProto.", - ONNX_NAMESPACE::AttributeProto::AttributeType::AttributeProto_AttributeType_INT, - static_cast(onnx::TensorProto_DataType::TensorProto_DataType_FLOAT)) - .Input(0, "num_mel_bins", "The number of bands in the mel spectrum.", "T1") - .Input(1, "dft_length", "The size of the FFT.", "T1") - .Input(2, "sample_rate", "", "T1") - .Input(3, "lower_edge_hertz", "", "T2") - .Input(4, "upper_edge_hertz", "", "T2") - .Output(0, "output", "The MEL Matrix", "T3") - .TypeConstraint("T1", {"tensor(int32)", "tensor(int64)"}, "Constrain to integer tensors.") - .TypeConstraint("T2", {"tensor(float)", "tensor(float16)", "tensor(double)", "tensor(bfloat16)"}, - "Constrain to float tensors") - .TypeConstraint("T3", ONNX_NAMESPACE::OpSchema::all_numeric_types_with_bfloat(), - "Constrain to any numerical types.") - .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { - auto output_datatype = getAttribute( - ctx, "output_datatype", static_cast(onnx::TensorProto::DataType::TensorProto_DataType_FLOAT)); - updateOutputElemType(ctx, 0, static_cast(output_datatype)); - - if (!hasInputShape(ctx, 0) || !hasInputShape(ctx, 1)) { - return; - } - - const auto* num_mel_bins = ctx.getInputData(0); - const auto* dft_length = ctx.getInputData(1); - if (nullptr == num_mel_bins || nullptr == dft_length) { - return; - } - - int64_t num_mel_bins_value = -1; - int64_t dft_length_value = -1; - if (num_mel_bins->dims_size() != 0) { - fail_shape_inference("num_mel_bins input must be scalar."); - } - num_mel_bins_value = get_scalar_value_from_tensor(num_mel_bins); - - if (dft_length->dims_size() != 0) { - fail_shape_inference("dft_length input must be scalar."); - } - dft_length_value = get_scalar_value_from_tensor(dft_length); - - if (num_mel_bins_value > 0 && dft_length_value > 0) { - ONNX_NAMESPACE::TensorShapeProto result_shape; - result_shape.add_dim()->set_dim_value(static_cast((dft_length_value >> 1) + 1)); 
- result_shape.add_dim()->set_dim_value(num_mel_bins_value); - updateOutputShape(ctx, 0, result_shape); - } - }) - .FunctionBody({mel_function_node}, {onnx_op_set_17}); -} - -} // namespace signal -} // namespace onnxruntime - -#endif diff --git a/onnxruntime/core/graph/signal_ops/signal_defs.h b/onnxruntime/core/graph/signal_ops/signal_defs.h deleted file mode 100644 index 6960ff33f6e61..0000000000000 --- a/onnxruntime/core/graph/signal_ops/signal_defs.h +++ /dev/null @@ -1,36 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#if !defined(ORT_MINIMAL_BUILD) -#include "onnx/defs/schema.h" -#else -#include "onnx/defs/data_type_utils.h" -#endif -#include "onnx/onnx_pb.h" -#include "onnx/onnx-operators_pb.h" - -namespace onnxruntime { -namespace signal { -#define MS_SIGNAL_OPERATOR_SCHEMA(name) \ - MS_SIGNAL_OPERATOR_SCHEMA_UNIQ_HELPER(__COUNTER__, name) -#define MS_SIGNAL_OPERATOR_SCHEMA_UNIQ_HELPER(Counter, name) \ - MS_SIGNAL_OPERATOR_SCHEMA_UNIQ(Counter, name) -#define MS_SIGNAL_OPERATOR_SCHEMA_UNIQ(Counter, name) \ - static ONNX_NAMESPACE::OpSchemaRegistry::OpSchemaRegisterOnce( \ - op_schema_register_once##name##Counter) ONNX_UNUSED = \ - ONNX_NAMESPACE::OpSchema(#name, __FILE__, __LINE__) - -#define MS_SIGNAL_OPERATOR_SCHEMA_ELSEWHERE(name, schema_func) \ - MS_SIGNAL_OPERATOR_SCHEMA_UNIQ_HELPER_ELSEWHERE(__COUNTER__, name, schema_func) -#define MS_SIGNAL_OPERATOR_SCHEMA_UNIQ_HELPER_ELSEWHERE(Counter, name, schema_func) \ - MS_SIGNAL_OPERATOR_SCHEMA_UNIQ_ELSEWHERE(Counter, name, schema_func) -#define MS_SIGNAL_OPERATOR_SCHEMA_UNIQ_ELSEWHERE(Counter, name, schema_func) \ - static ONNX_NAMESPACE::OpSchemaRegistry::OpSchemaRegisterOnce( \ - op_schema_register_once##name##Counter) ONNX_UNUSED = \ - schema_func(ONNX_NAMESPACE::OpSchema(#name, __FILE__, __LINE__)) - -void RegisterSignalSchemas(); -} // namespace dml -} // namespace onnxruntime From 1a380b3dc8750df699b30ea12c895c8299f1f63d Mon 
Sep 17 00:00:00 2001 From: Gary Miguel Date: Thu, 16 Jun 2022 00:05:09 +0000 Subject: [PATCH 06/20] simplify get_scalar_value_from_tensor --- onnxruntime/core/providers/cpu/signal/utils.h | 15 +-------------- 1 file changed, 1 insertion(+), 14 deletions(-) diff --git a/onnxruntime/core/providers/cpu/signal/utils.h b/onnxruntime/core/providers/cpu/signal/utils.h index c3bc949844d03..8e9e828c533a9 100644 --- a/onnxruntime/core/providers/cpu/signal/utils.h +++ b/onnxruntime/core/providers/cpu/signal/utils.h @@ -11,20 +11,7 @@ namespace signal { template static T get_scalar_value_from_tensor(const Tensor* tensor) { ORT_ENFORCE(tensor->Shape().Size() == 1, "ratio input should have a single value."); - - auto data_type = tensor->GetElementType(); - switch (data_type) { - case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: - return static_cast(*reinterpret_cast(tensor->DataRaw())); - case ONNX_NAMESPACE::TensorProto_DataType_DOUBLE: - return static_cast(*reinterpret_cast(tensor->DataRaw())); - case ONNX_NAMESPACE::TensorProto_DataType_INT32: - return static_cast(*reinterpret_cast(tensor->DataRaw())); - case ONNX_NAMESPACE::TensorProto_DataType_INT64: - return static_cast(*reinterpret_cast(tensor->DataRaw())); - default: - ORT_THROW("Unsupported input data type of ", data_type); - } + return *tensor->Data(); } } // namespace signal From 492b957ad5d0e5239bc377f8fe843e7b6cb6cade Mon Sep 17 00:00:00 2001 From: Gary Miguel Date: Thu, 16 Jun 2022 00:15:19 +0000 Subject: [PATCH 07/20] move compute_exponential to a helper --- onnxruntime/core/providers/cpu/signal/dft.cc | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/onnxruntime/core/providers/cpu/signal/dft.cc b/onnxruntime/core/providers/cpu/signal/dft.cc index c4a5bf680eaaf..ff67a37c2ac33 100644 --- a/onnxruntime/core/providers/cpu/signal/dft.cc +++ b/onnxruntime/core/providers/cpu/signal/dft.cc @@ -155,6 +155,12 @@ static T compute_angular_velocity(size_t number_of_samples, bool inverse) { return 
angular_velocity; } +template +static std::complex compute_exponential(size_t index, const T angular_velocity) { + const T angle = static_cast(index) * angular_velocity; + return std::complex(cos(angle), sin(angle)); +} + template static Status fft_radix2(OpKernelContext* /*ctx*/, const Tensor* X, Tensor* Y, size_t X_offset, size_t X_stride, size_t Y_offset, size_t Y_stride, int64_t axis, size_t dft_length, const Tensor* window, @@ -192,8 +198,7 @@ static Status fft_radix2(OpKernelContext* /*ctx*/, const Tensor* X, Tensor* Y, s V = InlinedVector>(dft_length); // e^(i *2*pi / N * k) for (size_t i = 0; i < dft_length; i++) { size_t bit_reversed_index = bit_reverse(i, significant_bits); - const T angle = static_cast(i) * angular_velocity; - V[bit_reversed_index] = std::complex(cos(angle), sin(angle)); + V[bit_reversed_index] = compute_exponential(i, angular_velocity); } } @@ -270,8 +275,7 @@ static Status dft_naive(const Tensor* X, Tensor* Y, size_t X_offset, size_t X_st out.imag(0); for (size_t j = 0; j < dft_length; j++) { // vectorize over this loop - const T angle = static_cast(i) * static_cast(j) * angular_velocity; - auto exponential = std::complex(cos(angle), sin(angle)); + auto exponential = compute_exponential(i * j, angular_velocity); auto window_element = window_data ? *(window_data + j) : 1; auto x = (j < number_of_samples) ? *(X_data + j * X_stride) : 0; auto element = x * window_element; From 82c0eed07455f475d099acb2fe2e8db3c8164c37 Mon Sep 17 00:00:00 2001 From: Gary Miguel Date: Thu, 16 Jun 2022 19:30:08 +0000 Subject: [PATCH 08/20] Update OperatorKernels.md --- docs/OperatorKernels.md | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index 36436ec368ffa..61b93c3e5e115 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -43,6 +43,7 @@ Do not modify directly.* |||[9, 13]|**T** = tensor(double), tensor(float)| |||[7, 8]|**T** = tensor(double), tensor(float)| |BitShift|*in* X:**T**
*in* Y:**T**
*out* Z:**T**|11+|**T** = tensor(uint32), tensor(uint64), tensor(uint8)| +|BlackmanWindow||17+|**T1** = tensor(int32), tensor(int64)
**T2** = tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |Cast|*in* input:**T1**
*out* output:**T2**|13+|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||[6, 12]|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |Ceil|*in* X:**T**
*out* Y:**T**|13+|**T** = tensor(float)| @@ -69,6 +70,7 @@ Do not modify directly.* |Crop|*in* input:**T**
*out* output:**T**|1+|**T** = tensor(float)| |CumSum|*in* x:**T**
*in* axis:**T2**
*out* y:**T**|14+|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)
**T2** = tensor(int32), tensor(int64)| |||[11, 13]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)
**T2** = tensor(int32), tensor(int64)| +|DFT||17+|**T1** = tensor(double), tensor(float)
**T2** = tensor(int32), tensor(int64)| |DepthToSpace|*in* input:**T**
*out* output:**T**|13+|**T** = tensor(double), tensor(float)| |||[11, 12]|**T** = tensor(double), tensor(float)| |||[1, 10]|**T** = tensor(double), tensor(float)| @@ -125,6 +127,8 @@ Do not modify directly.* |GreaterOrEqual|*in* A:**T**
*in* B:**T**
*out* C:**T1**|16+|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)
**T1** = tensor(bool)| |||[12, 15]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)
**T1** = tensor(bool)| |GridSample|*in* X:**T1**
*in* grid:**T1**
*out* Y:**T2**|16+|**T1** = tensor(float)
**T2** = tensor(float)| +|HammingWindow||17+|**T1** = tensor(int32), tensor(int64)
**T2** = tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|HannWindow||17+|**T1** = tensor(int32), tensor(int64)
**T2** = tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |HardSigmoid|*in* X:**T**
*out* Y:**T**|6+|**T** = tensor(float)| |Hardmax|*in* input:**T**
*out* output:**T**|13+|**T** = tensor(float)| |||[11, 12]|**T** = tensor(float)| @@ -186,6 +190,7 @@ Do not modify directly.* |MeanVarianceNormalization|*in* X:**T**
*out* Y:**T**

or

*in* input:**T**
*out* output:**T**|13+|**T** = tensor(float)| |||[9, 12]|**T** = tensor(float)| |||[1, 8]|**T** = tensor(float)| +|MelWeightMatrix||17+|**T1** = tensor(int32), tensor(int64)
**T2** = tensor(float)
**T3** = tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |Min|*in* data_0:**T**
*out* min:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)| |||12|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)| |||[8, 11]|**T** = tensor(double), tensor(float)| @@ -277,6 +282,7 @@ Do not modify directly.* |RoiAlign|*in* X:**T1**
*in* rois:**T1**
*in* batch_indices:**T2**
*out* Y:**T1**|16+|**T1** = tensor(double), tensor(float)
**T2** = tensor(int64)| |||[10, 15]|**T1** = tensor(double), tensor(float)
**T2** = tensor(int64)| |Round|*in* X:**T**
*out* Y:**T**|11+|**T** = tensor(double), tensor(float), tensor(float16)| +|STFT||17+|**T1** = tensor(double), tensor(float)
**T2** = tensor(int32), tensor(int64)| |Scale|*in* input:**T**
*out* output:**T**|1+|**T** = tensor(float)| |ScaledTanh|*in* input:**T**
*out* output:**T**|1+|**T** = tensor(float)| |Scan|*in* initial_state_and_scan_inputs:**V**
*out* final_state_and_scan_outputs:**V**

or

*in* sequence_lens:**I**
*in* initial_state_and_scan_inputs:**V**
*out* final_state_and_scan_outputs:**V**|16+|**V** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| From c736ed6cc71338630a95c8f26301b497b3819359 Mon Sep 17 00:00:00 2001 From: Gary Miguel Date: Fri, 17 Jun 2022 00:38:43 +0000 Subject: [PATCH 09/20] move tests --- .../test/contrib_ops/signal_ops_test.cc | 207 ------------------ .../providers/cpu/signal/signal_ops_test.cc | 198 +++++++++++++++++ 2 files changed, 198 insertions(+), 207 deletions(-) delete mode 100644 onnxruntime/test/contrib_ops/signal_ops_test.cc create mode 100644 onnxruntime/test/providers/cpu/signal/signal_ops_test.cc diff --git a/onnxruntime/test/contrib_ops/signal_ops_test.cc b/onnxruntime/test/contrib_ops/signal_ops_test.cc deleted file mode 100644 index 3fe4ce75e604e..0000000000000 --- a/onnxruntime/test/contrib_ops/signal_ops_test.cc +++ /dev/null @@ -1,207 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#ifdef BUILD_MS_EXPERIMENTAL_OPS - -#include "gtest/gtest.h" -#include "test/providers/provider_test_utils.h" - -namespace onnxruntime { -namespace test { - -static void TestNaiveDFTFloat(bool is_onesided) { - OpTester test("DFT", 1, onnxruntime::kMSExperimentalDomain); - - std::vector shape = {1, 5}; - std::vector output_shape = {1, 5, 2}; - output_shape[1] = is_onesided ? 
(1 + (shape[1] >> 1)) : shape[1]; - - std::vector input = {1, 2, 3, 4, 5}; - std::vector expected_output = { - 15.000000f, 0.0000000f, - -2.499999f, 3.4409550f, - -2.500000f, 0.8123000f, - -2.499999f, -0.812299f, - -2.500003f, -3.440953f - }; - - if (is_onesided) { - expected_output.resize(6); - } - test.AddInput("input", shape, input); - test.AddAttribute("onesided", static_cast(is_onesided)); - test.AddOutput("output", output_shape, expected_output); - test.Run(); -} - -static void TestRadix2DFTFloat(bool is_onesided) { - OpTester test("DFT", 1, onnxruntime::kMSExperimentalDomain); - - std::vector shape = {1, 8}; - std::vector output_shape = {1, 8, 2}; - output_shape[1] = is_onesided ? (1 + (shape[1] >> 1)) : shape[1]; - - std::vector input = {1, 2, 3, 4, 5, 6, 7, 8}; - std::vector expected_output = { - 36.000f, 0.000f, - -4.000f, 9.65685f, - -4.000f, 4.000f, - -4.000f, 1.65685f, - -4.000f, 0.000f, - -4.000f, -1.65685f, - -4.000f, -4.000f, - -4.000f, -9.65685f - }; - - if (is_onesided) { - expected_output.resize(10); - } - test.AddInput("input", shape, input); - test.AddAttribute("onesided", static_cast(is_onesided)); - test.AddOutput("output", output_shape, expected_output); - test.Run(); -} - -TEST(MLSignalOpTest, DFTFloat) { - TestNaiveDFTFloat(false); - TestNaiveDFTFloat(true); - TestRadix2DFTFloat(false); - TestRadix2DFTFloat(true); -} - -TEST(MLSignalOpTest, IDFTFloat) { - OpTester test("IDFT", 1, onnxruntime::kMSExperimentalDomain); - - std::vector shape = {1, 5, 2}; - std::vector input = - { - 15.000000f, 0.0000000f, - -2.499999f, 3.4409550f, - -2.500000f, 0.8123000f, - -2.499999f, -0.812299f, - -2.500003f, -3.440953f - }; - std::vector expected_output = - { - 1.000f, 0.000f, - 2.000f, 0.000f, - 3.000f, 0.000f, - 4.000f, 0.000f, - 5.000f, 0.000f - }; - - test.AddInput("input", shape, input); - test.AddOutput("output", shape, expected_output); - test.Run(); -} - -TEST(MLSignalOpTest, STFTFloat) { - OpTester test("STFT", 1, 
onnxruntime::kMSExperimentalDomain); - - std::vector signal(64, 1); - test.AddInput("signal", {1, 64}, signal); - std::vector window(16, 1); - test.AddInput("window", {16}, window); - test.AddInput("frame_length", {}, {16}); - test.AddInput("frame_step", {}, {8}); - - std::vector output_shape = {1, 7, 9, 2}; - std::vector expected_output = - { - 16.000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, - 16.000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, - 16.000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, - 16.000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, - 16.000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, - 16.000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, - 16.000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f - }; - test.AddOutput("output", output_shape, expected_output); - test.Run(); -} - -TEST(MLSignalOpTest, HannWindowFloat) { - OpTester test("HannWindow", 1, onnxruntime::kMSExperimentalDomain); - - std::vector scalar_shape = {}; - std::vector output_shape = {32}; - std::vector expected_output = - { - 0.000000f, 0.009607f, 0.038060f, 0.084265f, 0.146447f, 0.222215f, 0.308658f, 0.402455f, - 0.500000f, 0.597545f, 0.691342f, 0.777785f, 0.853553f, 0.915735f, 0.961940f, 0.990393f, - 1.000000f, 0.990393f, 0.961940f, 0.915735f, 0.853553f, 0.777785f, 0.691342f, 0.597545f, 
- 0.500000f, 0.402455f, 0.308658f, 0.222215f, 0.146447f, 0.084265f, 0.038060f, 0.009607f - }; - - test.AddInput("size", scalar_shape, {32}); - test.AddOutput("output", output_shape, expected_output); - test.Run(); -} - -TEST(MLSignalOpTest, HammingWindowFloat) { - OpTester test("HammingWindow", 1, onnxruntime::kMSExperimentalDomain); - - std::vector scalar_shape = {}; - std::vector output_shape = {32}; - std::vector expected_output = - { - 0.086957f, 0.095728f, 0.121707f, 0.163894f, 0.220669f, 0.289848f, 0.368775f, 0.454415f, - 0.543478f, 0.632541f, 0.718182f, 0.797108f, 0.866288f, 0.923062f, 0.965249f, 0.991228f, - 1.000000f, 0.991228f, 0.965249f, 0.923062f, 0.866288f, 0.797108f, 0.718182f, 0.632541f, - 0.543478f, 0.454415f, 0.368775f, 0.289848f, 0.220669f, 0.163894f, 0.121707f, 0.095728f - }; - - test.AddInput("size", scalar_shape, {32}); - test.AddOutput("output", output_shape, expected_output); - test.Run(); -} - -TEST(MLSignalOpTest, BlackmanWindowFloat) { - OpTester test("BlackmanWindow", 1, onnxruntime::kMSExperimentalDomain); - - std::vector scalar_shape = {}; - std::vector output_shape = {32}; - std::vector expected_output = - { - 0.000000f, 0.003518f, 0.014629f, 0.034880f, 0.066447f, 0.111600f, 0.172090f, 0.248544f, - 0.340000f, 0.443635f, 0.554773f, 0.667170f, 0.773553f, 0.866350f, 0.938508f, 0.984303f, - 1.000000f, 0.984303f, 0.938508f, 0.866350f, 0.773553f, 0.667170f, 0.554773f, 0.443635f, - 0.340000f, 0.248544f, 0.172090f, 0.111600f, 0.066447f, 0.034880f, 0.014629f, 0.003518f - }; - - test.AddInput("size", scalar_shape, {32}); - test.AddOutput("output", output_shape, expected_output); - test.Run(); -} - -TEST(MLSignalOpTest, MelWeightMatrixFloat) { - OpTester test("MelWeightMatrix", 1, onnxruntime::kMSExperimentalDomain); - - std::vector scalar_shape = {}; - std::vector output_shape = {9, 8}; - std::vector expected_output = - { - 1.000000f, 1.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, - 0.000000f, 0.000000f, 1.000000f, 
1.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, - 0.000000f, 0.000000f, 0.000000f, 0.000000f, 1.000000f, 0.000000f, 0.000000f, 0.000000f, - 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 1.000000f, 0.000000f, 0.000000f, - 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 1.000000f, 0.000000f, - 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 1.000000f, - 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, - 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, - 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f - }; - - test.AddInput("num_mel_bins", scalar_shape, {8}); - test.AddInput("dft_length", scalar_shape, {16}); - test.AddInput("sample_rate", scalar_shape, {8192}); - test.AddInput("lower_edge_hertz", scalar_shape, {0}); - test.AddInput("upper_edge_hertz", scalar_shape, {8192 / 2.f}); - test.AddOutput("output", output_shape, expected_output); - test.Run(); -} - -} // namespace test -} // namespace onnxruntime - -#endif \ No newline at end of file diff --git a/onnxruntime/test/providers/cpu/signal/signal_ops_test.cc b/onnxruntime/test/providers/cpu/signal/signal_ops_test.cc new file mode 100644 index 0000000000000..fc679555eaffd --- /dev/null +++ b/onnxruntime/test/providers/cpu/signal/signal_ops_test.cc @@ -0,0 +1,198 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "gtest/gtest.h" +#include "test/providers/provider_test_utils.h" + +namespace onnxruntime { +namespace test { + +static void TestNaiveDFTFloat(bool is_onesided) { + OpTester test("DFT", 17); + + std::vector shape = {1, 5, 1}; + std::vector output_shape = {1, 5, 2}; + output_shape[1] = is_onesided ? 
(1 + (shape[1] >> 1)) : shape[1]; + + std::vector input = {1, 2, 3, 4, 5}; + std::vector expected_output = { + 15.000000f, 0.0000000f, + -2.499999f, 3.4409550f, + -2.500000f, 0.8123000f, + -2.499999f, -0.812299f, + -2.500003f, -3.440953f}; + + if (is_onesided) { + expected_output.resize(6); + } + test.AddInput("input", shape, input); + test.AddAttribute("onesided", static_cast(is_onesided)); + test.AddOutput("output", output_shape, expected_output); + test.Run(); +} + +static void TestRadix2DFTFloat(bool is_onesided) { + OpTester test("DFT", 17); + + std::vector shape = {1, 8, 1}; + std::vector output_shape = {1, 8, 2}; + output_shape[1] = is_onesided ? (1 + (shape[1] >> 1)) : shape[1]; + + std::vector input = {1, 2, 3, 4, 5, 6, 7, 8}; + std::vector expected_output = { + 36.000f, 0.000f, + -4.000f, 9.65685f, + -4.000f, 4.000f, + -4.000f, 1.65685f, + -4.000f, 0.000f, + -4.000f, -1.65685f, + -4.000f, -4.000f, + -4.000f, -9.65685f}; + + if (is_onesided) { + expected_output.resize(10); + } + test.AddInput("input", shape, input); + test.AddAttribute("onesided", static_cast(is_onesided)); + test.AddOutput("output", output_shape, expected_output); + test.Run(); +} + +TEST(SignalOpsTest, DFTFloat_Naive) { + TestNaiveDFTFloat(false); + TestNaiveDFTFloat(true); +} + +TEST(SignalOpsTest, DFTFloat_Radix2) { + TestRadix2DFTFloat(false); + TestRadix2DFTFloat(true); +} + +TEST(SignalOpsTest, DFTFloat_inverse) { + OpTester test("DFT", 17); + + std::vector shape = {1, 5, 2}; + std::vector input = + { + 15.000000f, 0.0000000f, + -2.499999f, 3.4409550f, + -2.500000f, 0.8123000f, + -2.499999f, -0.812299f, + -2.500003f, -3.440953f}; + std::vector expected_output = + { + 1.000f, 0.000f, + 2.000f, 0.000f, + 3.000f, 0.000f, + 4.000f, 0.000f, + 5.000f, 0.000f}; + + test.AddInput("input", shape, input); + test.AddAttribute("inverse", static_cast(true)); + test.AddOutput("output", shape, expected_output); + test.Run(); +} + +TEST(SignalOpsTest, STFTFloat) { + OpTester test("STFT", 17); + + 
std::vector signal(64, 1); + test.AddInput("signal", {1, 64, 1}, signal); + std::vector window(16, 1); + test.AddInput("window", {16}, window); + test.AddInput("frame_length", {}, {16}); + test.AddInput("frame_step", {}, {8}); + + std::vector output_shape = {1, 7, 9, 2}; + std::vector expected_output = + { + 16.000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, + 16.000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, + 16.000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, + 16.000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, + 16.000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, + 16.000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, + 16.000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f}; + test.AddOutput("output", output_shape, expected_output); + test.Run(); +} + +TEST(SignalOpsTest, HannWindowFloat) { + OpTester test("HannWindow", 17); + + std::vector scalar_shape = {}; + std::vector output_shape = {32}; + std::vector expected_output = + { + 0.000000f, 0.009607f, 0.038060f, 0.084265f, 0.146447f, 0.222215f, 0.308658f, 0.402455f, + 0.500000f, 0.597545f, 0.691342f, 0.777785f, 0.853553f, 0.915735f, 0.961940f, 0.990393f, + 1.000000f, 0.990393f, 0.961940f, 0.915735f, 0.853553f, 0.777785f, 0.691342f, 0.597545f, + 0.500000f, 0.402455f, 0.308658f, 0.222215f, 0.146447f, 0.084265f, 
0.038060f, 0.009607f}; + + test.AddInput("size", scalar_shape, {32}); + test.AddOutput("output", output_shape, expected_output); + test.Run(); +} + +TEST(SignalOpsTest, HammingWindowFloat) { + OpTester test("HammingWindow", 17); + + std::vector scalar_shape = {}; + std::vector output_shape = {32}; + std::vector expected_output = + { + 0.086957f, 0.095728f, 0.121707f, 0.163894f, 0.220669f, 0.289848f, 0.368775f, 0.454415f, + 0.543478f, 0.632541f, 0.718182f, 0.797108f, 0.866288f, 0.923062f, 0.965249f, 0.991228f, + 1.000000f, 0.991228f, 0.965249f, 0.923062f, 0.866288f, 0.797108f, 0.718182f, 0.632541f, + 0.543478f, 0.454415f, 0.368775f, 0.289848f, 0.220669f, 0.163894f, 0.121707f, 0.095728f}; + + test.AddInput("size", scalar_shape, {32}); + test.AddOutput("output", output_shape, expected_output); + test.Run(); +} + +TEST(SignalOpsTest, BlackmanWindowFloat) { + OpTester test("BlackmanWindow", 17); + + std::vector scalar_shape = {}; + std::vector output_shape = {32}; + std::vector expected_output = + { + 0.000000f, 0.003518f, 0.014629f, 0.034880f, 0.066447f, 0.111600f, 0.172090f, 0.248544f, + 0.340000f, 0.443635f, 0.554773f, 0.667170f, 0.773553f, 0.866350f, 0.938508f, 0.984303f, + 1.000000f, 0.984303f, 0.938508f, 0.866350f, 0.773553f, 0.667170f, 0.554773f, 0.443635f, + 0.340000f, 0.248544f, 0.172090f, 0.111600f, 0.066447f, 0.034880f, 0.014629f, 0.003518f}; + + test.AddInput("size", scalar_shape, {32}); + test.AddOutput("output", output_shape, expected_output); + test.Run(); +} + +TEST(SignalOpsTest, MelWeightMatrixFloat) { + OpTester test("MelWeightMatrix", 17); + + std::vector scalar_shape = {}; + std::vector output_shape = {9, 8}; + std::vector expected_output = + { + 1.000000f, 1.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, + 0.000000f, 0.000000f, 1.000000f, 1.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, + 0.000000f, 0.000000f, 0.000000f, 0.000000f, 1.000000f, 0.000000f, 0.000000f, 0.000000f, + 0.000000f, 0.000000f, 0.000000f, 
0.000000f, 0.000000f, 1.000000f, 0.000000f, 0.000000f, + 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 1.000000f, 0.000000f, + 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 1.000000f, + 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, + 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, + 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f}; + + test.AddInput("num_mel_bins", scalar_shape, {8}); + test.AddInput("dft_length", scalar_shape, {16}); + test.AddInput("sample_rate", scalar_shape, {8192}); + test.AddInput("lower_edge_hertz", scalar_shape, {0}); + test.AddInput("upper_edge_hertz", scalar_shape, {8192 / 2.f}); + test.AddOutput("output", output_shape, expected_output); + test.Run(); +} + +} // namespace test +} // namespace onnxruntime From 0d1765cc83097d4f28e082dee6a7fbfdea4f338a Mon Sep 17 00:00:00 2001 From: Gary Miguel Date: Fri, 17 Jun 2022 00:44:59 +0000 Subject: [PATCH 10/20] fix input order --- onnxruntime/test/providers/cpu/signal/signal_ops_test.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/test/providers/cpu/signal/signal_ops_test.cc b/onnxruntime/test/providers/cpu/signal/signal_ops_test.cc index fc679555eaffd..c168ef5c12ce7 100644 --- a/onnxruntime/test/providers/cpu/signal/signal_ops_test.cc +++ b/onnxruntime/test/providers/cpu/signal/signal_ops_test.cc @@ -98,10 +98,10 @@ TEST(SignalOpsTest, STFTFloat) { std::vector signal(64, 1); test.AddInput("signal", {1, 64, 1}, signal); + test.AddInput("frame_step", {}, {8}); std::vector window(16, 1); test.AddInput("window", {16}, window); test.AddInput("frame_length", {}, {16}); - test.AddInput("frame_step", {}, {8}); std::vector output_shape = {1, 7, 9, 2}; std::vector expected_output = From dbb3e2f9528198f4ca631fce0000b19a42248db0 Mon Sep 17 00:00:00 2001 From: Gary Miguel Date: Fri, 17 Jun 2022 22:41:01 +0000 
Subject: [PATCH 11/20] don't use a switch statement for bit_reverse --- onnxruntime/core/providers/cpu/signal/dft.cc | 82 ++------------------ 1 file changed, 6 insertions(+), 76 deletions(-) diff --git a/onnxruntime/core/providers/cpu/signal/dft.cc b/onnxruntime/core/providers/cpu/signal/dft.cc index ff67a37c2ac33..9f7ab3a97c138 100644 --- a/onnxruntime/core/providers/cpu/signal/dft.cc +++ b/onnxruntime/core/providers/cpu/signal/dft.cc @@ -64,85 +64,15 @@ static const unsigned char BitReverseTable256[] = { 0x27, 0xA7, 0x67, 0xE7, 0x17, 0x97, 0x57, 0xD7, 0x37, 0xB7, 0x77, 0xF7, 0x0F, 0x8F, 0x4F, 0xCF, 0x2F, 0xAF, 0x6F, 0xEF, 0x1F, 0x9F, 0x5F, 0xDF, 0x3F, 0xBF, 0x7F, 0xFF}; -template -uint32_t bit_reverse(uint32_t num) { - uint32_t rev = (BitReverseTable256[num & 0xff] << 24) | (BitReverseTable256[(num >> 8) & 0xff] << 16) | - (BitReverseTable256[(num >> 16) & 0xff] << 8) | (BitReverseTable256[(num >> 24) & 0xff]); - return static_cast(((uint64_t)rev) >> (32 - TSignificantBits)); -} - template static inline T bit_reverse(T num, unsigned significant_bits) { - switch (significant_bits) { - case 0: - return static_cast(bit_reverse<0>(static_cast(num))); - case 1: - return static_cast(bit_reverse<1>(static_cast(num))); - case 2: - return static_cast(bit_reverse<2>(static_cast(num))); - case 3: - return static_cast(bit_reverse<3>(static_cast(num))); - case 4: - return static_cast(bit_reverse<4>(static_cast(num))); - case 5: - return static_cast(bit_reverse<5>(static_cast(num))); - case 6: - return static_cast(bit_reverse<6>(static_cast(num))); - case 7: - return static_cast(bit_reverse<7>(static_cast(num))); - case 8: - return static_cast(bit_reverse<8>(static_cast(num))); - case 9: - return static_cast(bit_reverse<9>(static_cast(num))); - case 10: - return static_cast(bit_reverse<10>(static_cast(num))); - case 11: - return static_cast(bit_reverse<11>(static_cast(num))); - case 12: - return static_cast(bit_reverse<12>(static_cast(num))); - case 13: - return 
static_cast(bit_reverse<13>(static_cast(num))); - case 14: - return static_cast(bit_reverse<14>(static_cast(num))); - case 15: - return static_cast(bit_reverse<15>(static_cast(num))); - case 16: - return static_cast(bit_reverse<16>(static_cast(num))); - case 17: - return static_cast(bit_reverse<17>(static_cast(num))); - case 18: - return static_cast(bit_reverse<18>(static_cast(num))); - case 19: - return static_cast(bit_reverse<19>(static_cast(num))); - case 20: - return static_cast(bit_reverse<20>(static_cast(num))); - case 21: - return static_cast(bit_reverse<21>(static_cast(num))); - case 22: - return static_cast(bit_reverse<22>(static_cast(num))); - case 23: - return static_cast(bit_reverse<23>(static_cast(num))); - case 24: - return static_cast(bit_reverse<24>(static_cast(num))); - case 25: - return static_cast(bit_reverse<25>(static_cast(num))); - case 26: - return static_cast(bit_reverse<26>(static_cast(num))); - case 27: - return static_cast(bit_reverse<27>(static_cast(num))); - case 28: - return static_cast(bit_reverse<28>(static_cast(num))); - case 29: - return static_cast(bit_reverse<29>(static_cast(num))); - case 30: - return static_cast(bit_reverse<30>(static_cast(num))); - case 31: - return static_cast(bit_reverse<31>(static_cast(num))); - case 32: - return static_cast(bit_reverse<32>(static_cast(num))); - default: - ORT_THROW("Unsupported bit size."); + if (significant_bits > 32) { + ORT_THROW("Unsupported bit size."); } + uint32_t num_32 = static_cast(num); + uint32_t rev = (BitReverseTable256[num_32 & 0xff] << 24) | (BitReverseTable256[(num_32 >> 8) & 0xff] << 16) | + (BitReverseTable256[(num_32 >> 16) & 0xff] << 8) | (BitReverseTable256[(num_32 >> 24) & 0xff]); + return static_cast(((uint64_t)rev) >> (32 - significant_bits)); } template From 02fc960941d5a7a2e159546af88d1f3297fb2f26 Mon Sep 17 00:00:00 2001 From: Gary Miguel Date: Fri, 17 Jun 2022 22:41:30 +0000 Subject: [PATCH 12/20] test that FFT is invertible --- 
.../providers/cpu/signal/signal_ops_test.cc | 275 ++++++++++-------- 1 file changed, 161 insertions(+), 114 deletions(-) diff --git a/onnxruntime/test/providers/cpu/signal/signal_ops_test.cc b/onnxruntime/test/providers/cpu/signal/signal_ops_test.cc index c168ef5c12ce7..d2e635bf5afff 100644 --- a/onnxruntime/test/providers/cpu/signal/signal_ops_test.cc +++ b/onnxruntime/test/providers/cpu/signal/signal_ops_test.cc @@ -1,91 +1,76 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#include +#include + #include "gtest/gtest.h" +#include "test/common/tensor_op_test_utils.h" #include "test/providers/provider_test_utils.h" +#include "test/util/include/test_random_seed.h" + +using std::vector; namespace onnxruntime { namespace test { -static void TestNaiveDFTFloat(bool is_onesided) { - OpTester test("DFT", 17); +static const int kMinOpsetVersion = 17; - std::vector shape = {1, 5, 1}; - std::vector output_shape = {1, 5, 2}; - output_shape[1] = is_onesided ? (1 + (shape[1] >> 1)) : shape[1]; +static void TestNaiveDFTFloat(bool onesided) { + OpTester test("DFT", kMinOpsetVersion); - std::vector input = {1, 2, 3, 4, 5}; - std::vector expected_output = { - 15.000000f, 0.0000000f, - -2.499999f, 3.4409550f, - -2.500000f, 0.8123000f, - -2.499999f, -0.812299f, - -2.500003f, -3.440953f}; + vector shape = {1, 5, 1}; + vector output_shape = {1, 5, 2}; + output_shape[1] = onesided ? 
(1 + (shape[1] >> 1)) : shape[1]; - if (is_onesided) { + vector input = {1, 2, 3, 4, 5}; + vector expected_output = {15.000000f, 0.0000000f, -2.499999f, 3.4409550f, -2.500000f, + 0.8123000f, -2.499999f, -0.812299f, -2.500003f, -3.440953f}; + + if (onesided) { expected_output.resize(6); } test.AddInput("input", shape, input); - test.AddAttribute("onesided", static_cast(is_onesided)); + test.AddAttribute("onesided", static_cast(onesided)); test.AddOutput("output", output_shape, expected_output); test.Run(); } -static void TestRadix2DFTFloat(bool is_onesided) { - OpTester test("DFT", 17); - - std::vector shape = {1, 8, 1}; - std::vector output_shape = {1, 8, 2}; - output_shape[1] = is_onesided ? (1 + (shape[1] >> 1)) : shape[1]; - - std::vector input = {1, 2, 3, 4, 5, 6, 7, 8}; - std::vector expected_output = { - 36.000f, 0.000f, - -4.000f, 9.65685f, - -4.000f, 4.000f, - -4.000f, 1.65685f, - -4.000f, 0.000f, - -4.000f, -1.65685f, - -4.000f, -4.000f, - -4.000f, -9.65685f}; - - if (is_onesided) { +static void TestRadix2DFTFloat(bool onesided) { + OpTester test("DFT", kMinOpsetVersion); + + vector shape = {1, 8, 1}; + vector output_shape = {1, 8, 2}; + output_shape[1] = onesided ? 
(1 + (shape[1] >> 1)) : shape[1]; + + vector input = {1, 2, 3, 4, 5, 6, 7, 8}; + vector expected_output = {36.000f, 0.000f, -4.000f, 9.65685f, -4.000f, 4.000f, -4.000f, 1.65685f, + -4.000f, 0.000f, -4.000f, -1.65685f, -4.000f, -4.000f, -4.000f, -9.65685f}; + + if (onesided) { expected_output.resize(10); } test.AddInput("input", shape, input); - test.AddAttribute("onesided", static_cast(is_onesided)); + test.AddAttribute("onesided", static_cast(onesided)); test.AddOutput("output", output_shape, expected_output); test.Run(); } -TEST(SignalOpsTest, DFTFloat_Naive) { - TestNaiveDFTFloat(false); - TestNaiveDFTFloat(true); -} +TEST(SignalOpsTest, DFTFloat_naive) { TestNaiveDFTFloat(false); } -TEST(SignalOpsTest, DFTFloat_Radix2) { - TestRadix2DFTFloat(false); - TestRadix2DFTFloat(true); -} +TEST(SignalOpsTest, DFTFloat_naive_onesided) { TestNaiveDFTFloat(true); } + +TEST(SignalOpsTest, DFTFloat_radix2) { TestRadix2DFTFloat(false); } + +TEST(SignalOpsTest, DFTFloat_radix2_onesided) { TestRadix2DFTFloat(true); } TEST(SignalOpsTest, DFTFloat_inverse) { - OpTester test("DFT", 17); - - std::vector shape = {1, 5, 2}; - std::vector input = - { - 15.000000f, 0.0000000f, - -2.499999f, 3.4409550f, - -2.500000f, 0.8123000f, - -2.499999f, -0.812299f, - -2.500003f, -3.440953f}; - std::vector expected_output = - { - 1.000f, 0.000f, - 2.000f, 0.000f, - 3.000f, 0.000f, - 4.000f, 0.000f, - 5.000f, 0.000f}; + OpTester test("DFT", kMinOpsetVersion); + + vector shape = {1, 5, 2}; + vector input = {15.000000f, 0.0000000f, -2.499999f, 3.4409550f, -2.500000f, + 0.8123000f, -2.499999f, -0.812299f, -2.500003f, -3.440953f}; + vector expected_output = {1.000f, 0.000f, 2.000f, 0.000f, 3.000f, 0.000f, 4.000f, 0.000f, 5.000f, 0.000f}; test.AddInput("input", shape, input); test.AddAttribute("inverse", static_cast(true)); @@ -93,41 +78,107 @@ TEST(SignalOpsTest, DFTFloat_inverse) { test.Run(); } +// Tests that FFT(FFT(x), inverse=true) == x +static void TestDFTInvertible(bool complex) { + // TODO: test 
dft_length + class DFTInvertibleTester : public OpTester { + public: + DFTInvertibleTester(int64_t axis) : OpTester("DFT", kMinOpsetVersion), axis_(axis) {} + + protected: + void AddNodes(Graph& graph, vector& graph_inputs, vector& graph_outputs, + vector>& add_attribute_funcs) override { + // Create an intermediate output + vector intermediate_outputs = graph_outputs; + ONNX_NAMESPACE::TypeProto type_info = *intermediate_outputs[0]->TypeAsProto(); // copy + NodeArg& dft_output = graph.GetOrCreateNodeArg("dft_output", &type_info); + intermediate_outputs[0] = &dft_output; + + // call base implementation to add the DFT node. + OpTester::AddNodes(graph, graph_inputs, intermediate_outputs, add_attribute_funcs); + OpTester::AddAttribute("axis", axis_); + + Node& inverse = graph.AddNode("inverse", "DFT", "inverse", intermediate_outputs, graph_outputs); + inverse.AddAttribute("inverse", static_cast(true)); + inverse.AddAttribute("axis", axis_); + } + + private: + int64_t axis_; + }; + + RandomValueGenerator random(GetTestRandomSeed()); + // TODO(garymm, smk2007): Add tests for different dft_length values. + const int64_t num_batches = 2; + for (int64_t axis = 1; axis < 2; axis += 1) { + for (int64_t signal_dim1 = 1; signal_dim1 <= 4; signal_dim1 += 1) { + for (int64_t signal_dim2 = 1; signal_dim2 <= 4; signal_dim2 += 1) { + DFTInvertibleTester test(axis); + vector input_shape{num_batches, signal_dim1, signal_dim2, 1 + complex}; + vector input_data = random.Uniform(input_shape, -100.f, 100.f); + test.AddInput("input", input_shape, input_data); + + vector output_shape(input_shape); + vector* output_data_p; + vector output_data; + if (complex) { + output_data_p = &input_data; + } else { // real -> (real, imaginary) with imaginary == 0. 
+ output_shape[3] = 2; + output_data.resize(input_data.size() * 2, 0); + for (size_t i = 0; i < input_data.size(); i += 1) { + output_data[i * 2] = input_data[i]; + } + output_data_p = &output_data; + } + test.AddOutput("output", output_shape, *output_data_p); + test.Run(); + } + } + } +} + +TEST(SignalOpsTest, DFT_invertible_real) { TestDFTInvertible(false); } + +TEST(SignalOpsTest, DFT_invertible_complex) { TestDFTInvertible(true); } + TEST(SignalOpsTest, STFTFloat) { - OpTester test("STFT", 17); + OpTester test("STFT", kMinOpsetVersion); - std::vector signal(64, 1); + vector signal(64, 1); test.AddInput("signal", {1, 64, 1}, signal); test.AddInput("frame_step", {}, {8}); - std::vector window(16, 1); + vector window(16, 1); test.AddInput("window", {16}, window); test.AddInput("frame_length", {}, {16}); - std::vector output_shape = {1, 7, 9, 2}; - std::vector expected_output = - { - 16.000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, - 16.000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, - 16.000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, - 16.000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, - 16.000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, - 16.000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, - 16.000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f}; + vector 
output_shape = {1, 7, 9, 2}; + vector expected_output = { + 16.000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, + 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 16.000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, + 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, + 16.000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, + 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 16.000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, + 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, + 16.000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, + 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 16.000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, + 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, + 16.000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f, + 0.0000f, 0.000f, 0.0000f, 0.000f, 0.0000f, 0.000f}; test.AddOutput("output", output_shape, expected_output); test.Run(); } TEST(SignalOpsTest, HannWindowFloat) { - OpTester test("HannWindow", 17); + OpTester test("HannWindow", kMinOpsetVersion); - std::vector scalar_shape = {}; - std::vector output_shape = {32}; - std::vector expected_output = - { - 0.000000f, 0.009607f, 0.038060f, 0.084265f, 0.146447f, 0.222215f, 0.308658f, 0.402455f, - 0.500000f, 0.597545f, 0.691342f, 0.777785f, 0.853553f, 0.915735f, 0.961940f, 0.990393f, - 1.000000f, 0.990393f, 0.961940f, 0.915735f, 0.853553f, 0.777785f, 0.691342f, 0.597545f, - 0.500000f, 0.402455f, 0.308658f, 0.222215f, 0.146447f, 0.084265f, 0.038060f, 0.009607f}; + vector scalar_shape = {}; + vector output_shape = {32}; + vector expected_output = {0.000000f, 0.009607f, 0.038060f, 0.084265f, 0.146447f, 0.222215f, 0.308658f, + 0.402455f, 0.500000f, 
0.597545f, 0.691342f, 0.777785f, 0.853553f, 0.915735f, + 0.961940f, 0.990393f, 1.000000f, 0.990393f, 0.961940f, 0.915735f, 0.853553f, + 0.777785f, 0.691342f, 0.597545f, 0.500000f, 0.402455f, 0.308658f, 0.222215f, + 0.146447f, 0.084265f, 0.038060f, 0.009607f}; test.AddInput("size", scalar_shape, {32}); test.AddOutput("output", output_shape, expected_output); @@ -135,16 +186,15 @@ TEST(SignalOpsTest, HannWindowFloat) { } TEST(SignalOpsTest, HammingWindowFloat) { - OpTester test("HammingWindow", 17); + OpTester test("HammingWindow", kMinOpsetVersion); - std::vector scalar_shape = {}; - std::vector output_shape = {32}; - std::vector expected_output = - { - 0.086957f, 0.095728f, 0.121707f, 0.163894f, 0.220669f, 0.289848f, 0.368775f, 0.454415f, - 0.543478f, 0.632541f, 0.718182f, 0.797108f, 0.866288f, 0.923062f, 0.965249f, 0.991228f, - 1.000000f, 0.991228f, 0.965249f, 0.923062f, 0.866288f, 0.797108f, 0.718182f, 0.632541f, - 0.543478f, 0.454415f, 0.368775f, 0.289848f, 0.220669f, 0.163894f, 0.121707f, 0.095728f}; + vector scalar_shape = {}; + vector output_shape = {32}; + vector expected_output = // + {0.086957f, 0.095728f, 0.121707f, 0.163894f, 0.220669f, 0.289848f, 0.368775f, 0.454415f, + 0.543478f, 0.632541f, 0.718182f, 0.797108f, 0.866288f, 0.923062f, 0.965249f, 0.991228f, + 1.000000f, 0.991228f, 0.965249f, 0.923062f, 0.866288f, 0.797108f, 0.718182f, 0.632541f, + 0.543478f, 0.454415f, 0.368775f, 0.289848f, 0.220669f, 0.163894f, 0.121707f, 0.095728f}; test.AddInput("size", scalar_shape, {32}); test.AddOutput("output", output_shape, expected_output); @@ -152,16 +202,15 @@ TEST(SignalOpsTest, HammingWindowFloat) { } TEST(SignalOpsTest, BlackmanWindowFloat) { - OpTester test("BlackmanWindow", 17); + OpTester test("BlackmanWindow", kMinOpsetVersion); - std::vector scalar_shape = {}; - std::vector output_shape = {32}; - std::vector expected_output = - { - 0.000000f, 0.003518f, 0.014629f, 0.034880f, 0.066447f, 0.111600f, 0.172090f, 0.248544f, - 0.340000f, 0.443635f, 0.554773f, 
0.667170f, 0.773553f, 0.866350f, 0.938508f, 0.984303f, - 1.000000f, 0.984303f, 0.938508f, 0.866350f, 0.773553f, 0.667170f, 0.554773f, 0.443635f, - 0.340000f, 0.248544f, 0.172090f, 0.111600f, 0.066447f, 0.034880f, 0.014629f, 0.003518f}; + vector scalar_shape = {}; + vector output_shape = {32}; + vector expected_output = // + {0.000000f, 0.003518f, 0.014629f, 0.034880f, 0.066447f, 0.111600f, 0.172090f, 0.248544f, + 0.340000f, 0.443635f, 0.554773f, 0.667170f, 0.773553f, 0.866350f, 0.938508f, 0.984303f, + 1.000000f, 0.984303f, 0.938508f, 0.866350f, 0.773553f, 0.667170f, 0.554773f, 0.443635f, + 0.340000f, 0.248544f, 0.172090f, 0.111600f, 0.066447f, 0.034880f, 0.014629f, 0.003518f}; test.AddInput("size", scalar_shape, {32}); test.AddOutput("output", output_shape, expected_output); @@ -169,21 +218,19 @@ TEST(SignalOpsTest, BlackmanWindowFloat) { } TEST(SignalOpsTest, MelWeightMatrixFloat) { - OpTester test("MelWeightMatrix", 17); - - std::vector scalar_shape = {}; - std::vector output_shape = {9, 8}; - std::vector expected_output = - { - 1.000000f, 1.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, - 0.000000f, 0.000000f, 1.000000f, 1.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, - 0.000000f, 0.000000f, 0.000000f, 0.000000f, 1.000000f, 0.000000f, 0.000000f, 0.000000f, - 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 1.000000f, 0.000000f, 0.000000f, - 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 1.000000f, 0.000000f, - 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 1.000000f, - 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, - 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, - 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f}; + OpTester test("MelWeightMatrix", kMinOpsetVersion); + + vector scalar_shape = {}; + vector output_shape = {9, 8}; + vector expected_output = 
{ + 1.000000f, 1.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, + 0.000000f, 1.000000f, 1.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, + 0.000000f, 0.000000f, 1.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, + 0.000000f, 0.000000f, 1.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, + 0.000000f, 0.000000f, 1.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, + 0.000000f, 0.000000f, 1.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, + 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, + 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f}; test.AddInput("num_mel_bins", scalar_shape, {8}); test.AddInput("dft_length", scalar_shape, {16}); From 529add6b1710019441c8a06de9715934ef201063 Mon Sep 17 00:00:00 2001 From: Gary Miguel Date: Wed, 22 Jun 2022 00:35:04 +0000 Subject: [PATCH 13/20] Enable tests --- .../test/testdata/onnx_backend_test_series_filters.jsonc | 1 - .../test/testdata/onnx_backend_test_series_overrides.jsonc | 4 +++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc b/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc index 091c573d38ae0..95b3516312f16 100644 --- a/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc +++ b/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc @@ -4,7 +4,6 @@ // // Tests that are failing temporarily and should be fixed "current_failing_tests": [ - "^test_(blackmanwindow|dft|hammingwindow|hannwindow|melweightmatrix|stft).*", // https://github.com/microsoft/onnxruntime/pull/11778 "^test_adagrad", "^test_adagrad_multiple", "^test_batchnorm_epsilon_old", diff --git a/onnxruntime/test/testdata/onnx_backend_test_series_overrides.jsonc 
b/onnxruntime/test/testdata/onnx_backend_test_series_overrides.jsonc index 921e491b63510..8b2ec0246809e 100644 --- a/onnxruntime/test/testdata/onnx_backend_test_series_overrides.jsonc +++ b/onnxruntime/test/testdata/onnx_backend_test_series_overrides.jsonc @@ -5,7 +5,9 @@ // Val: float, max absolute difference between expected and actual. "atol_overrides": { "test_dft": 1e-4, - "test_dft_axis": 1e-4 + "test_dft_axis": 1e-4, + "test_stft": 1e-4, + "test_stft_with_window": 1e-4 }, // Key: str, the name of the test as defined by ONNX without any device suffix. // Val: float, max relative difference between expected and actual. From 1ecaa634010d1c48220a90fd4719c56032058c46 Mon Sep 17 00:00:00 2001 From: Gary Miguel Date: Wed, 22 Jun 2022 17:25:48 +0000 Subject: [PATCH 14/20] Fix get_scalar_value_from_tensor --- onnxruntime/core/providers/cpu/signal/utils.h | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/onnxruntime/core/providers/cpu/signal/utils.h b/onnxruntime/core/providers/cpu/signal/utils.h index 8e9e828c533a9..a5ff5df6e5d48 100644 --- a/onnxruntime/core/providers/cpu/signal/utils.h +++ b/onnxruntime/core/providers/cpu/signal/utils.h @@ -11,7 +11,19 @@ namespace signal { template static T get_scalar_value_from_tensor(const Tensor* tensor) { ORT_ENFORCE(tensor->Shape().Size() == 1, "ratio input should have a single value."); - return *tensor->Data(); + const auto data_type = tensor->GetElementType(); + switch (data_type) { + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: + return static_cast(*tensor->Data()); + case ONNX_NAMESPACE::TensorProto_DataType_DOUBLE: + return static_cast(*tensor->Data()); + case ONNX_NAMESPACE::TensorProto_DataType_INT32: + return static_cast(*tensor->Data()); + case ONNX_NAMESPACE::TensorProto_DataType_INT64: + return static_cast(*tensor->Data()); + default: + ORT_THROW("Unsupported input data type of ", data_type); + } } } // namespace signal From 01a1cccea490d8939efda3b1660ca55f9db5b6a5 Mon Sep 17 
00:00:00 2001 From: Gary Miguel Date: Wed, 22 Jun 2022 18:15:10 +0000 Subject: [PATCH 15/20] Support periodic attribute --- .../core/providers/cpu/signal/window_functions.cc | 15 ++++++++------- .../core/providers/cpu/signal/window_functions.h | 12 ++++++++++++ 2 files changed, 20 insertions(+), 7 deletions(-) diff --git a/onnxruntime/core/providers/cpu/signal/window_functions.cc b/onnxruntime/core/providers/cpu/signal/window_functions.cc index 552f930a46e8f..4ddd76641a6eb 100644 --- a/onnxruntime/core/providers/cpu/signal/window_functions.cc +++ b/onnxruntime/core/providers/cpu/signal/window_functions.cc @@ -48,13 +48,14 @@ ONNX_CPU_OPERATOR_KERNEL(MelWeightMatrix, 17, template struct CosineSumWindow { - Status operator()(Tensor* Y, size_t size, float a0, float a1, float a2) { + Status operator()(Tensor* Y, size_t size, float a0, float a1, float a2, bool is_periodic) { auto* Y_data = reinterpret_cast(Y->MutableDataRaw()); // Calculate the radians to increment per sample constexpr double pi = 3.14159265; constexpr double tau = 2 * pi; - const double angular_increment = tau / size; + const size_t denominator = is_periodic ? size : size - 1; + const double angular_increment = tau / denominator; for (size_t i = 0; i < size; i++) { auto a2_component = a2 == 0 ? 
0 : (a2 * cos(2 * angular_increment * i)); @@ -68,7 +69,7 @@ struct CosineSumWindow { }; static Status create_cosine_sum_window(OpKernelContext* ctx, onnx::TensorProto_DataType output_datatype, float a0, - float a1, float a2) { + float a1, float a2, bool is_periodic) { // Get the size of the window auto size = signal::get_scalar_value_from_tensor(ctx->Input(0)); @@ -78,7 +79,7 @@ static Status create_cosine_sum_window(OpKernelContext* ctx, onnx::TensorProto_D utils::MLTypeCallDispatcher dispatcher(output_datatype); - return dispatcher.InvokeRet(Y, size, a0, a1, a2); + return dispatcher.InvokeRet(Y, size, a0, a1, a2, is_periodic); } Status HannWindow::Compute(OpKernelContext* ctx) const { @@ -87,7 +88,7 @@ Status HannWindow::Compute(OpKernelContext* ctx) const { float a0 = .5f; float a1 = a0; float a2 = 0; - return create_cosine_sum_window(ctx, data_type_, a0, a1, a2); + return create_cosine_sum_window(ctx, data_type_, a0, a1, a2, is_periodic_); } Status HammingWindow::Compute(OpKernelContext* ctx) const { @@ -96,7 +97,7 @@ Status HammingWindow::Compute(OpKernelContext* ctx) const { float a0 = 25.f / 46.f; float a1 = 1 - a0; float a2 = 0; - return create_cosine_sum_window(ctx, data_type_, a0, a1, a2); + return create_cosine_sum_window(ctx, data_type_, a0, a1, a2, is_periodic_); } Status BlackmanWindow::Compute(OpKernelContext* ctx) const { @@ -106,7 +107,7 @@ Status BlackmanWindow::Compute(OpKernelContext* ctx) const { float a2 = alpha / 2.f; float a0 = .5f - a2; float a1 = .5f; - return create_cosine_sum_window(ctx, data_type_, a0, a1, a2); + return create_cosine_sum_window(ctx, data_type_, a0, a1, a2, is_periodic_); } static inline double hz_to_mel_scale(double hz) { return 2595 * std::log10(1 + hz / 700); } diff --git a/onnxruntime/core/providers/cpu/signal/window_functions.h b/onnxruntime/core/providers/cpu/signal/window_functions.h index 052c3ac43a16a..994149b5ced06 100644 --- a/onnxruntime/core/providers/cpu/signal/window_functions.h +++ 
b/onnxruntime/core/providers/cpu/signal/window_functions.h @@ -20,22 +20,34 @@ class VariableOutputDataTypeBase : public OpKernel { class HannWindow final : public VariableOutputDataTypeBase { public: explicit HannWindow(const OpKernelInfo& info) : VariableOutputDataTypeBase(info) { + is_periodic_ = static_cast(info.GetAttrOrDefault("periodic", 1)); } Status Compute(OpKernelContext* ctx) const override; + + private: + bool is_periodic_ = true; }; class HammingWindow final : public VariableOutputDataTypeBase { public: explicit HammingWindow(const OpKernelInfo& info) : VariableOutputDataTypeBase(info) { + is_periodic_ = static_cast(info.GetAttrOrDefault("periodic", 1)); } Status Compute(OpKernelContext* ctx) const override; + + private: + bool is_periodic_ = true; }; class BlackmanWindow final : public VariableOutputDataTypeBase { public: explicit BlackmanWindow(const OpKernelInfo& info) : VariableOutputDataTypeBase(info) { + is_periodic_ = static_cast(info.GetAttrOrDefault("periodic", 1)); } Status Compute(OpKernelContext* ctx) const override; + + private: + bool is_periodic_ = true; }; class MelWeightMatrix final : public VariableOutputDataTypeBase { From 283ed5ce600ead5364f4b25892e3304b3abef3bb Mon Sep 17 00:00:00 2001 From: Gary Miguel Date: Thu, 23 Jun 2022 01:24:04 +0000 Subject: [PATCH 16/20] simplify test --- onnxruntime/test/providers/cpu/signal/signal_ops_test.cc | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/onnxruntime/test/providers/cpu/signal/signal_ops_test.cc b/onnxruntime/test/providers/cpu/signal/signal_ops_test.cc index d2e635bf5afff..2db126b140607 100644 --- a/onnxruntime/test/providers/cpu/signal/signal_ops_test.cc +++ b/onnxruntime/test/providers/cpu/signal/signal_ops_test.cc @@ -44,7 +44,7 @@ static void TestRadix2DFTFloat(bool onesided) { output_shape[1] = onesided ? 
(1 + (shape[1] >> 1)) : shape[1]; vector input = {1, 2, 3, 4, 5, 6, 7, 8}; - vector expected_output = {36.000f, 0.000f, -4.000f, 9.65685f, -4.000f, 4.000f, -4.000f, 1.65685f, + vector expected_output = {36.000f, 0.000f, -4.000f, 9.65685f, -4.000f, 4.000f, -4.000f, 1.65685f, -4.000f, 0.000f, -4.000f, -1.65685f, -4.000f, -4.000f, -4.000f, -9.65685f}; if (onesided) { @@ -89,10 +89,7 @@ static void TestDFTInvertible(bool complex) { void AddNodes(Graph& graph, vector& graph_inputs, vector& graph_outputs, vector>& add_attribute_funcs) override { // Create an intermediate output - vector intermediate_outputs = graph_outputs; - ONNX_NAMESPACE::TypeProto type_info = *intermediate_outputs[0]->TypeAsProto(); // copy - NodeArg& dft_output = graph.GetOrCreateNodeArg("dft_output", &type_info); - intermediate_outputs[0] = &dft_output; + vector intermediate_outputs{&graph.GetOrCreateNodeArg("dft_output", graph_outputs[0]->TypeAsProto())}; // call base implementation to add the DFT node. OpTester::AddNodes(graph, graph_inputs, intermediate_outputs, add_attribute_funcs); From 7f3308887312f4b3cc6aabeb4592e6f74b9bd39e Mon Sep 17 00:00:00 2001 From: Gary Miguel Date: Thu, 23 Jun 2022 01:24:37 +0000 Subject: [PATCH 17/20] check for out of bounds access --- onnxruntime/core/providers/cpu/signal/dft.cc | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/onnxruntime/core/providers/cpu/signal/dft.cc b/onnxruntime/core/providers/cpu/signal/dft.cc index 9f7ab3a97c138..c877b227630fa 100644 --- a/onnxruntime/core/providers/cpu/signal/dft.cc +++ b/onnxruntime/core/providers/cpu/signal/dft.cc @@ -113,7 +113,8 @@ static Status fft_radix2(OpKernelContext* /*ctx*/, const Tensor* X, Tensor* Y, s std::complex* Y_data; if (is_onesided) { if (temp_output.size() != dft_length) { - temp_output = InlinedVector>(dft_length); + temp_output.clear(); + temp_output.resize(dft_length); } Y_data = temp_output.data(); } else { @@ -125,7 +126,8 @@ static Status 
fft_radix2(OpKernelContext* /*ctx*/, const Tensor* X, Tensor* Y, s // Create vandermonde matrix V ordered with the bit-reversed permutation if (V.size() != dft_length) { - V = InlinedVector>(dft_length); // e^(i *2*pi / N * k) + V.clear(); + V.resize(dft_length); for (size_t i = 0; i < dft_length; i++) { size_t bit_reversed_index = bit_reverse(i, significant_bits); V[bit_reversed_index] = compute_exponential(i, angular_velocity); @@ -153,6 +155,14 @@ static Status fft_radix2(OpKernelContext* /*ctx*/, const Tensor* X, Tensor* Y, s auto odd_index = k + j + midpoint; std::complex* even = (Y_data + even_index * Y_data_stride); std::complex* odd = (Y_data + odd_index * Y_data_stride); + if (is_onesided) { + if (even > &temp_output[temp_output.size() - 1]) { + ORT_THROW("even is out of range"); + } + if (odd > &temp_output[temp_output.size() - 1]) { + ORT_THROW("odd is out of range"); + } + } std::complex first = *even + (V[first_idx] * *odd); std::complex second = *even + (V[second_idx] * *odd); *even = first; From e8088442305e7cb31e64a01ba0ecbb68c99fed3a Mon Sep 17 00:00:00 2001 From: Gary Miguel Date: Thu, 23 Jun 2022 19:44:15 +0000 Subject: [PATCH 18/20] update OperatorKernels.md --- docs/OperatorKernels.md | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index 61b93c3e5e115..102dce0737296 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -43,7 +43,7 @@ Do not modify directly.* |||[9, 13]|**T** = tensor(double), tensor(float)| |||[7, 8]|**T** = tensor(double), tensor(float)| |BitShift|*in* X:**T**
*in* Y:**T**
*out* Z:**T**|11+|**T** = tensor(uint32), tensor(uint64), tensor(uint8)| -|BlackmanWindow||17+|**T1** = tensor(int32), tensor(int64)
**T2** = tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|BlackmanWindow|*in* size:**T1**
*out* output:**T2**|17+|**T1** = tensor(int32), tensor(int64)
**T2** = tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |Cast|*in* input:**T1**
*out* output:**T2**|13+|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||[6, 12]|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |Ceil|*in* X:**T**
*out* Y:**T**|13+|**T** = tensor(float)| @@ -70,7 +70,7 @@ Do not modify directly.* |Crop|*in* input:**T**
*out* output:**T**|1+|**T** = tensor(float)| |CumSum|*in* x:**T**
*in* axis:**T2**
*out* y:**T**|14+|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)
**T2** = tensor(int32), tensor(int64)| |||[11, 13]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)
**T2** = tensor(int32), tensor(int64)| -|DFT||17+|**T1** = tensor(double), tensor(float)
**T2** = tensor(int32), tensor(int64)| +|DFT|*in* input:**T1**
*in* dft_length:**T2**
*out* output:**T1**|17+|**T1** = tensor(double), tensor(float)
**T2** = tensor(int32), tensor(int64)| |DepthToSpace|*in* input:**T**
*out* output:**T**|13+|**T** = tensor(double), tensor(float)| |||[11, 12]|**T** = tensor(double), tensor(float)| |||[1, 10]|**T** = tensor(double), tensor(float)| @@ -127,8 +127,8 @@ Do not modify directly.* |GreaterOrEqual|*in* A:**T**
*in* B:**T**
*out* C:**T1**|16+|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)
**T1** = tensor(bool)| |||[12, 15]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)
**T1** = tensor(bool)| |GridSample|*in* X:**T1**
*in* grid:**T1**
*out* Y:**T2**|16+|**T1** = tensor(float)
**T2** = tensor(float)| -|HammingWindow||17+|**T1** = tensor(int32), tensor(int64)
**T2** = tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| -|HannWindow||17+|**T1** = tensor(int32), tensor(int64)
**T2** = tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|HammingWindow|*in* size:**T1**
*out* output:**T2**|17+|**T1** = tensor(int32), tensor(int64)
**T2** = tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|HannWindow|*in* size:**T1**
*out* output:**T2**|17+|**T1** = tensor(int32), tensor(int64)
**T2** = tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |HardSigmoid|*in* X:**T**
*out* Y:**T**|6+|**T** = tensor(float)| |Hardmax|*in* input:**T**
*out* output:**T**|13+|**T** = tensor(float)| |||[11, 12]|**T** = tensor(float)| @@ -190,7 +190,7 @@ Do not modify directly.* |MeanVarianceNormalization|*in* X:**T**
*out* Y:**T**

or

*in* input:**T**
*out* output:**T**|13+|**T** = tensor(float)| |||[9, 12]|**T** = tensor(float)| |||[1, 8]|**T** = tensor(float)| -|MelWeightMatrix||17+|**T1** = tensor(int32), tensor(int64)
**T2** = tensor(float)
**T3** = tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|MelWeightMatrix|*in* num_mel_bins:**T1**
*in* dft_length:**T1**
*in* sample_rate:**T1**
*in* lower_edge_hertz:**T2**
*in* upper_edge_hertz:**T2**
*out* output:**T3**|17+|**T1** = tensor(int32), tensor(int64)
**T2** = tensor(float)
**T3** = tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |Min|*in* data_0:**T**
*out* min:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)| |||12|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)| |||[8, 11]|**T** = tensor(double), tensor(float)| @@ -282,7 +282,7 @@ Do not modify directly.* |RoiAlign|*in* X:**T1**
*in* rois:**T1**
*in* batch_indices:**T2**
*out* Y:**T1**|16+|**T1** = tensor(double), tensor(float)
**T2** = tensor(int64)| |||[10, 15]|**T1** = tensor(double), tensor(float)
**T2** = tensor(int64)| |Round|*in* X:**T**
*out* Y:**T**|11+|**T** = tensor(double), tensor(float), tensor(float16)| -|STFT||17+|**T1** = tensor(double), tensor(float)
**T2** = tensor(int32), tensor(int64)| +|STFT|*in* signal:**T1**
*in* frame_step:**T2**
*in* window:**T1**
*in* frame_length:**T2**
*out* output:**T1**|17+|**T1** = tensor(double), tensor(float)
**T2** = tensor(int32), tensor(int64)| |Scale|*in* input:**T**
*out* output:**T**|1+|**T** = tensor(float)| |ScaledTanh|*in* input:**T**
*out* output:**T**|1+|**T** = tensor(float)| |Scan|*in* initial_state_and_scan_inputs:**V**
*out* final_state_and_scan_outputs:**V**

or

*in* sequence_lens:**I**
*in* initial_state_and_scan_inputs:**V**
*out* final_state_and_scan_outputs:**V**|16+|**V** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| From e27222fad90c09b573abbf4bcf5f5caec02db05c Mon Sep 17 00:00:00 2001 From: Gary Miguel Date: Fri, 24 Jun 2022 00:11:03 +0000 Subject: [PATCH 19/20] remove bounds check, it did not fire --- onnxruntime/core/providers/cpu/signal/dft.cc | 8 -------- 1 file changed, 8 deletions(-) diff --git a/onnxruntime/core/providers/cpu/signal/dft.cc b/onnxruntime/core/providers/cpu/signal/dft.cc index c877b227630fa..5652d82167d13 100644 --- a/onnxruntime/core/providers/cpu/signal/dft.cc +++ b/onnxruntime/core/providers/cpu/signal/dft.cc @@ -155,14 +155,6 @@ static Status fft_radix2(OpKernelContext* /*ctx*/, const Tensor* X, Tensor* Y, s auto odd_index = k + j + midpoint; std::complex* even = (Y_data + even_index * Y_data_stride); std::complex* odd = (Y_data + odd_index * Y_data_stride); - if (is_onesided) { - if (even > &temp_output[temp_output.size() - 1]) { - ORT_THROW("even is out of range"); - } - if (odd > &temp_output[temp_output.size() - 1]) { - ORT_THROW("odd is out of range"); - } - } std::complex first = *even + (V[first_idx] * *odd); std::complex second = *even + (V[second_idx] * *odd); *even = first; From e6ec47eba3d9ab8ba343ebd09a0bf6dfc3b80033 Mon Sep 17 00:00:00 2001 From: Gary Miguel Date: Fri, 24 Jun 2022 19:38:08 +0000 Subject: [PATCH 20/20] fix onesided destination overflow --- onnxruntime/core/providers/cpu/signal/dft.cc | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/onnxruntime/core/providers/cpu/signal/dft.cc b/onnxruntime/core/providers/cpu/signal/dft.cc index 5652d82167d13..97d7e19a7c4b1 100644 --- a/onnxruntime/core/providers/cpu/signal/dft.cc +++ b/onnxruntime/core/providers/cpu/signal/dft.cc @@ -113,7 +113,6 @@ static Status fft_radix2(OpKernelContext* /*ctx*/, 
const Tensor* X, Tensor* Y, s std::complex* Y_data; if (is_onesided) { if (temp_output.size() != dft_length) { - temp_output.clear(); temp_output.resize(dft_length); } Y_data = temp_output.data(); @@ -126,7 +125,6 @@ static Status fft_radix2(OpKernelContext* /*ctx*/, const Tensor* X, Tensor* Y, s // Create vandermonde matrix V ordered with the bit-reversed permutation if (V.size() != dft_length) { - V.clear(); V.resize(dft_length); for (size_t i = 0; i < dft_length; i++) { size_t bit_reversed_index = bit_reverse(i, significant_bits); @@ -172,8 +170,9 @@ static Status fft_radix2(OpKernelContext* /*ctx*/, const Tensor* X, Tensor* Y, s } if (is_onesided) { + const size_t output_size = (dft_length >> 1) + 1; auto destination = reinterpret_cast*>(Y->MutableDataRaw()) + Y_offset; - for (size_t i = 0; i < dft_length; i++) { + for (size_t i = 0; i < output_size; i++) { *(destination + Y_stride * i) = *(Y_data + i * Y_data_stride); } }