diff --git a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc index 08b6a31938111..4f28baeafaf22 100644 --- a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc +++ b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc @@ -232,6 +232,8 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, float_float_float, OneHot); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, int64_t_int32_t_float, OneHot); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, int64_t_float_int64_t, OneHot); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, int32_t_float_int32_t, OneHot); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, int32_t_float_float, OneHot); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, MaxUnpool); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, Sinh); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, Cosh); @@ -508,6 +510,8 @@ void RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/core/providers/cpu/tensor/onehot.cc b/onnxruntime/core/providers/cpu/tensor/onehot.cc index 1dfbaaf37640f..9d9b1cff7470c 100644 --- a/onnxruntime/core/providers/cpu/tensor/onehot.cc +++ b/onnxruntime/core/providers/cpu/tensor/onehot.cc @@ -46,6 +46,8 @@ REG_ONE_HOT_OP(float, int64_t, int64_t); REG_ONE_HOT_OP(int64_t, string, int64_t); REG_ONE_HOT_OP(float, string, int64_t); REG_ONE_HOT_OP(int64_t, float, int64_t); +REG_ONE_HOT_OP(int32_t, float, int32_t); +REG_ONE_HOT_OP(int32_t, float, float); REG_ONE_HOT_OP(float, float, float); // added this to satisfy onnx model tests REG_ONE_HOT_OP(int64_t, int32_t, float); // added this to satisfy onnx model tests diff --git a/onnxruntime/test/providers/cpu/tensor/onehot_op_test.cc b/onnxruntime/test/providers/cpu/tensor/onehot_op_test.cc index c816c7f7b6661..63cebfa1e7a50 100644 --- a/onnxruntime/test/providers/cpu/tensor/onehot_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/onehot_op_test.cc @@ -65,6 +65,34 @@ TEST(OneHotOpTest, DefaultAxis_int64_float_int64 /*indices, output, depth*/) { test.Run(); } +TEST(OneHotOpTest, DefaultAxis_int32_float_float /*indices, output, depth*/) { + OpTester test("OneHot", 9); + test.AddInput("indices", {2, 3}, {1, 9, 8, 2, 4, 6}); + test.AddInput("depth", {1}, {10.0f}); + test.AddInput("values", {2}, {0.0f, 1.0f}); + test.AddOutput("output", {2, 3, 10}, {0.0f, 1.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, + 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 1.0f, + 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 1.0f, 0.0f, + 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, + 0.0f, 0.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, + 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f}); + test.Run(); +} + +TEST(OneHotOpTest, DefaultAxis_int32_float_int32 /*indices, output, depth*/) { + OpTester test("OneHot", 9); + test.AddInput("indices", {2, 3}, {1, 9, 8, 2, 4, 6}); + test.AddInput("depth", {1}, {10}); + test.AddInput("values", {2}, {0.0f, 1.0f}); + test.AddOutput("output", {2, 3, 10}, {0.0f, 1.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, + 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 1.0f, + 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 1.0f, 0.0f, + 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, + 0.0f, 0.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, + 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f}); + test.Run(); +} + TEST(OneHotOpTest, Axis_0) { OpTester test("OneHot", 9); int64_t axis = 0;