Skip to content

Commit

Permalink
Add more type support for OneHot op (#1565)
Browse files Browse the repository at this point in the history
  • Loading branch information
hariharans29 authored Aug 7, 2019
1 parent 9e926fe commit 9a34089
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 0 deletions.
4 changes: 4 additions & 0 deletions onnxruntime/core/providers/cpu/cpu_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,8 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, float_float_float, OneHot);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, int64_t_int32_t_float, OneHot);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, int64_t_float_int64_t, OneHot);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, int32_t_float_int32_t, OneHot);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, int32_t_float_float, OneHot);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, MaxUnpool);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, Sinh);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, Cosh);
Expand Down Expand Up @@ -508,6 +510,8 @@ void RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, float_float_float, OneHot)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, int64_t_int32_t_float, OneHot)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, int64_t_float_int64_t, OneHot)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, int32_t_float_int32_t, OneHot)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, int32_t_float_float, OneHot)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, MaxUnpool)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, Sinh)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, Cosh)>,
Expand Down
2 changes: 2 additions & 0 deletions onnxruntime/core/providers/cpu/tensor/onehot.cc
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ REG_ONE_HOT_OP(float, int64_t, int64_t);
REG_ONE_HOT_OP(int64_t, string, int64_t);
REG_ONE_HOT_OP(float, string, int64_t);
REG_ONE_HOT_OP(int64_t, float, int64_t);
REG_ONE_HOT_OP(int32_t, float, int32_t);
REG_ONE_HOT_OP(int32_t, float, float);
REG_ONE_HOT_OP(float, float, float); // added this to satisfy onnx model tests
REG_ONE_HOT_OP(int64_t, int32_t, float); // added this to satisfy onnx model tests

Expand Down
28 changes: 28 additions & 0 deletions onnxruntime/test/providers/cpu/tensor/onehot_op_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,34 @@ TEST(OneHotOpTest, DefaultAxis_int64_float_int64 /*indices, output, depth*/) {
test.Run();
}

TEST(OneHotOpTest, DefaultAxis_int32_float_float /*indices, output, depth*/) {
OpTester test("OneHot", 9);
test.AddInput<int32_t>("indices", {2, 3}, {1, 9, 8, 2, 4, 6});
test.AddInput<float>("depth", {1}, {10.0f});
test.AddInput<float>("values", {2}, {0.0f, 1.0f});
test.AddOutput<float>("output", {2, 3, 10}, {0.0f, 1.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 1.0f,
0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 1.0f, 0.0f,
0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
0.0f, 0.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f});
test.Run();
}

TEST(OneHotOpTest, DefaultAxis_int32_float_int32 /*indices, output, depth*/) {
OpTester test("OneHot", 9);
test.AddInput<int32_t>("indices", {2, 3}, {1, 9, 8, 2, 4, 6});
test.AddInput<int32_t>("depth", {1}, {10});
test.AddInput<float>("values", {2}, {0.0f, 1.0f});
test.AddOutput<float>("output", {2, 3, 10}, {0.0f, 1.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 1.0f,
0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 1.0f, 0.0f,
0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
0.0f, 0.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 0.0f});
test.Run();
}

TEST(OneHotOpTest, Axis_0) {
OpTester test("OneHot", 9);
int64_t axis = 0;
Expand Down

0 comments on commit 9a34089

Please sign in to comment.