From 9194c36b771f30cb508e022b4b76597419bbac7d Mon Sep 17 00:00:00 2001 From: hariharans29 Date: Thu, 22 Aug 2019 19:42:18 -0700 Subject: [PATCH 1/7] Support bilinear mode with actual 2D inputs in Resize and upsample --- .../core/providers/cpu/tensor/upsample.cc | 66 ++++++++++++------- .../core/providers/cpu/tensor/upsample.h | 7 +- .../providers/cpu/tensor/resize_op_test.cc | 48 +++++++++++++- .../providers/cpu/tensor/upsample_op_test.cc | 30 ++++++++- 4 files changed, 119 insertions(+), 32 deletions(-) diff --git a/onnxruntime/core/providers/cpu/tensor/upsample.cc b/onnxruntime/core/providers/cpu/tensor/upsample.cc index 95605dbef4a68..f9a964bdcae77 100644 --- a/onnxruntime/core/providers/cpu/tensor/upsample.cc +++ b/onnxruntime/core/providers/cpu/tensor/upsample.cc @@ -3,6 +3,7 @@ #include "core/providers/cpu/tensor/upsample.h" #include +#include using namespace onnxruntime::common; using namespace std; @@ -61,14 +62,18 @@ Status UpsampleNearest(const T* input, T* output, const TensorShape& input_shape, const TensorShape& output_shape, - const vector& scales) { + const vector& scales, + const bool is_resize) { if (!input || !output) - return Status(ONNXRUNTIME, FAIL, "Upsample: input/output value is nullptr"); + return Status(ONNXRUNTIME, FAIL, is_resize ? "Resize: input/output value is nullptr" : + "Upsample: input/output value is nullptr"); if (input_shape.NumDimensions() != output_shape.NumDimensions()) - return Status(ONNXRUNTIME, FAIL, "Upsample: input/output value's dimension mismatch"); + return Status(ONNXRUNTIME, FAIL, is_resize ? "Resize: input/output value's dimension mismatch" : + "Upsample: input/output value's dimension mismatch"); if (input_shape.NumDimensions() == 0) { return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, - "Upsample: input shape needs to be at least a single dimension."); + is_resize ? "Resize: input shape needs to be at least a single dimension" : + "Upsample: input shape needs to be at least a single dimension."); } int64_t n_dim = static_cast(input_shape.NumDimensions()); @@ -194,9 +199,11 @@ Status upsampleLiner(const T* input, const TensorShape& output_shape, const vector& scales) { if (!input || !output) - return Status(ONNXRUNTIME, FAIL, "Upsample: input/output value is nullptr"); + return Status(ONNXRUNTIME, FAIL, is_resize ? "Resize: input / output value is nullptr" : + "Upsample: input / output value is nullptr", ); if (input_shape.NumDimensions() != output_shape.NumDimensions()) - return Status(ONNXRUNTIME, FAIL, "Upsample: input/output value's dimension mismatch"); + return Status(ONNXRUNTIME, FAIL, is_resize ? "Resize: input/output value's dimension mismatch" : + "Upsample: input/output value's dimension mismatch"); auto n_dim = input_shape.NumDimensions(); for (size_t i = 0, size = output_shape.Size(); i < size; i++) { std::vector val1; @@ -242,6 +249,11 @@ Status upsampleLiner(const T* input, return Status::OK(); } +// The following method supports a 4-D input in 'Linear mode' +// that amounts to 'Bilinear' Upsampling/Resizing in the sense that it assumes +// the scale values for the outermost 2 dimensions are 1. +// This is the common use-case where the 4-D input (batched multi-channel images) +// is usually of shape [N, C, H, W] and the scales are [1.0, 1.0, height_scale, width_scale] template void upsampleBilinear( int64_t batch_size, @@ -327,9 +339,10 @@ Status Upsample::BaseCompute(OpKernelContext* context, const std::vector& dims = X->Shape().GetDims(); - if (dims.size() != scales.size()) { - return Status(ONNXRUNTIME, INVALID_ARGUMENT, "Upsample: input tensor's dimension does not match the scales."); - } + if (dims.size() != scales.size()) + return Status(ONNXRUNTIME, INVALID_ARGUMENT, + is_resize ? "Resize: input tensor's dimension does not match the scales." : + "Upsample: input tensor's dimension does not match the scales."); bool no_scale = true; std::vector Y_dims; @@ -348,26 +361,33 @@ Status Upsample::BaseCompute(OpKernelContext* context, const std::vector(X->template Data(), Y->template MutableData(), X->Shape(), Y->Shape(), scales); + return UpsampleNearest(X->template Data(), Y->template MutableData(), X->Shape(), Y->Shape(), scales, is_resize); case UpsampleMode::LINEAR: { - //What's the correct behavior of linear mode is not clear right now, - //Only support bilinear with 4D tensor to keep consistent with previous behavior - if (dims.size() != 4) - return Status(ONNXRUNTIME, FAIL, "Upsample: linear mode upsample only support 4-D tensor with NCHW layout"); + //The correct behavior of 'linear' mode for an N-D input is not clear right now, + //so only support 'bilinear' with 2-D or 4-D input tensor with outermost 2 scales as 1 in the 4-D case + if (dims.size() != 2 && dims.size() != 4) { + std::ostringstream oss; + oss << "'Linear' mode only support 2-D inputs ('Bilinear') or 4-D inputs " + "with the corresponding outermost 2 scale values being 1 in the "; + oss << is_resize ? "Resize operator" : "Upsample operator"; + return Status(ONNXRUNTIME, FAIL, oss.str()); + } - const int64_t batch_size = dims[0]; - const int64_t num_channels = dims[1]; - const int64_t input_height = dims[2]; - const int64_t input_width = dims[3]; + bool is_2D = dims.size() == 2; + const int64_t batch_size = is_2D ? 1 : dims[0]; + const int64_t num_channels = is_2D ? 1 : dims[1]; + const int64_t input_height = is_2D ? dims[0] : dims[2]; + const int64_t input_width = is_2D ? dims[1] : dims[3]; AllocatorPtr alloc; ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&alloc)); upsampleBilinear(batch_size, num_channels, input_height, input_width, - scales[2], scales[3], X->template Data(), Y->template MutableData(), alloc); + is_2D ? scales[0] : scales[2], is_2D ? scales[1] : scales[3], + X->template Data(), Y->template MutableData(), alloc); return Status::OK(); } default: - return Status(ONNXRUNTIME, FAIL, "Upsample: unexpected mode"); + return Status(ONNXRUNTIME, FAIL, is_resize ? "Resize: unexpected mode" : "Upsample: unexpected mode"); } } @@ -380,9 +400,9 @@ Status Upsample::Compute(OpKernelContext* context) const { const auto* scales = context->Input(1); ORT_ENFORCE(scales != nullptr); int64_t scales_size = scales->Shape().Size(); - std::vector scales_arrary(scales_size); - ParseScalesData(scales, scales_arrary); - return BaseCompute(context, scales_arrary); + std::vector scales_array(scales_size); + ParseScalesData(scales, scales_array); + return BaseCompute(context, scales_array); } } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cpu/tensor/upsample.h b/onnxruntime/core/providers/cpu/tensor/upsample.h index 5c57295af5195..97b41e0915d89 100644 --- a/onnxruntime/core/providers/cpu/tensor/upsample.h +++ b/onnxruntime/core/providers/cpu/tensor/upsample.h @@ -72,9 +72,10 @@ class UpsampleBase { } if (UpsampleMode::LINEAR == mode) { - ORT_ENFORCE(scales.size() == 4, "Upsample: linear mode upsample only support bilinear with 4 dimension."); - ORT_ENFORCE(((scales[0] == 1) && (scales[1] == 1)), - "Upsample: linear mode upsample only support bilinear, the first 2 scales should be 1."); + ORT_ENFORCE(scales.size() == 2 || (scales.size() == 4 && scales[0] == 1 && scales[1] == 1), + "'Linear' mode only support 2-D inputs ('Bilinear') or 4-D inputs " + "with the corresponding outermost 2 scale values being 1 in the ", + is_resize ? "Resize operator" : "Upsample operator"); } } diff --git a/onnxruntime/test/providers/cpu/tensor/resize_op_test.cc b/onnxruntime/test/providers/cpu/tensor/resize_op_test.cc index 0611aa2501937..3e790eacc8c4b 100644 --- a/onnxruntime/test/providers/cpu/tensor/resize_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/resize_op_test.cc @@ -7,7 +7,7 @@ namespace onnxruntime { namespace test { -TEST(ResizeOpTest, ResizeOpLineartDownSampleTest) { +TEST(ResizeOpTest, ResizeOpLineartDownSampleTest_4DBilinear) { OpTester test("Resize", 10); std::vector scales{1.0f, 1.0f, 0.6f, 0.6f}; @@ -27,7 +27,26 @@ TEST(ResizeOpTest, ResizeOpLineartDownSampleTest) { test.Run(); } -TEST(ResizeOpTest, ResizeOpLineartUpSampleTest) { +TEST(ResizeOpTest, ResizeOpLineartDownSampleTest_2DBilinear) { + OpTester test("Resize", 10); + std::vector scales{0.6f, 0.6f}; + + test.AddAttribute("mode", "linear"); + + const int64_t H = 2, W = 4; + std::vector X = { + 1.0f, 2.0f, 3.0f, 4.0f, + 5.0f, 6.0f, 7.0f, 8.0f}; + + test.AddInput("X", {H, W}, X); + test.AddInput("scales", {2}, scales); + + std::vector Y = {1.0f, 2.66666651f}; + + test.AddOutput("Y", {(int64_t)(H * scales[0]), (int64_t)(W * scales[1])}, Y); + test.Run(); +} +TEST(ResizeOpTest, ResizeOpLineartUpSampleTest_4DBilinear) { OpTester test("Resize", 10); std::vector scales{1.0f, 1.0f, 2.0f, 4.0f}; test.AddAttribute("mode", "linear"); @@ -57,7 +76,30 @@ TEST(ResizeOpTest, ResizeOpLineartUpSampleTest) { test.Run(); } -TEST(ResizeOpTest, ResizeOpLineartNoScaleTest) { +TEST(ResizeOpTest, ResizeOpLineartUpSampleTest_2DBilinear) { + OpTester test("Resize", 10); + std::vector scales{2.0f, 4.0f}; + test.AddAttribute("mode", "linear"); + + const int64_t H = 2, W = 2; + std::vector X = {1.0f, 3.0f, + 4.0f, 8.0f}; + + test.AddInput("X", {H, W}, X); + test.AddInput("scales", {2}, scales); + + std::vector Y = { + 1.0f, 1.5f, 2.0f, 2.5f, 3.0f, 3.0f, 3.0f, 3.0f, + 2.5f, 3.25f, 4.0f, 4.75f, 5.5f, 5.5f, 5.5f, 5.5f, + 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 8.0f, 8.0f, 8.0f, + 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 8.0f, 8.0f, 8.0f + }; + + test.AddOutput("Y", {(int64_t)(H * scales[0]), (int64_t)(W * scales[1])}, Y); + test.Run(); +} + +TEST(ResizeOpTest, ResizeOpLineartScalesNoOpTest) { OpTester test("Resize", 10); std::vector scales{1.0f, 1.0f, 1.0f, 1.0f}; test.AddAttribute("mode", "linear"); diff --git a/onnxruntime/test/providers/cpu/tensor/upsample_op_test.cc b/onnxruntime/test/providers/cpu/tensor/upsample_op_test.cc index 68924aa60b3b0..e7a67bc12d682 100644 --- a/onnxruntime/test/providers/cpu/tensor/upsample_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/upsample_op_test.cc @@ -264,7 +264,7 @@ TEST(UpsampleOpTest, UpsampleOpNearest2XTest_int32) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); //TensorRT: nvinfer1::query::Ports&): Assertion `!formats.empty()' failed } -TEST(UpsampleOpTest, UpsampleOpBilinearTest) { +TEST(UpsampleOpTest, UpsampleOp4DBilinearTest) { OpTester test("Upsample"); std::vector scales{1.0f, 1.0f, 2.0f, 4.0f}; @@ -295,7 +295,31 @@ TEST(UpsampleOpTest, UpsampleOpBilinearTest) { test.Run(); } -TEST(UpsampleOpTest, UpsampleOpBilinearTest_NoScale) { +TEST(UpsampleOpTest, UpsampleOp2DBilinearTest) { + OpTester test("Upsample"); + + std::vector scales{2.0f, 4.0f}; + test.AddAttribute("mode", "linear"); + test.AddAttribute("scales", scales); + + const int64_t H = 2, W = 2; + std::vector X = {1.0f, 3.0f, + 3.0f, 5.0f}; + + test.AddInput("X", {H, W}, X); + + std::vector Y = { + 1.0f, 1.5f, 2.0f, 2.5f, 3.0f, 3.0f, 3.0f, 3.0f, + 2.0f, 2.5f, 3.0f, 3.5f, 4.0f, 4.0f, 4.0f, 4.0f, + 3.0f, 3.5f, 4.0f, 4.5f, 5.0f, 5.0f, 5.0f, 5.0f, + 3.0f, 3.5f, 4.0f, 4.5f, 5.0f, 5.0f, 5.0f, 5.0f + }; + + test.AddOutput("Y", {(int64_t)(H * scales[0]), (int64_t)(W * scales[1])}, Y); + test.Run(); +} + +TEST(UpsampleOpTest, UpsampleOp4DBilinearTest_ScalesNoOp) { OpTester test("Upsample"); std::vector scales{1.0f, 1.0f, 1.0f, 1.0f}; @@ -321,7 +345,7 @@ TEST(UpsampleOpTest, UpsampleOpBilinearTest_NoScale) { test.Run(); } -TEST(UpsampleOpTest, UpsampleOpBilinearTest_int32) { +TEST(UpsampleOpTest, UpsampleOp4DBilinearTest_int32) { OpTester test("Upsample"); std::vector scales{1.0f, 1.0f, 2.0f, 4.0f}; From 01bfd4f6e2a310f2c71f9bc2558c61d6afa94ece Mon Sep 17 00:00:00 2001 From: hariharans29 Date: Fri, 23 Aug 2019 12:21:26 -0700 Subject: [PATCH 2/7] Fix build break --- onnxruntime/core/providers/cpu/tensor/upsample.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/core/providers/cpu/tensor/upsample.cc b/onnxruntime/core/providers/cpu/tensor/upsample.cc index f9a964bdcae77..8234705ec00f9 100644 --- a/onnxruntime/core/providers/cpu/tensor/upsample.cc +++ b/onnxruntime/core/providers/cpu/tensor/upsample.cc @@ -369,7 +369,7 @@ Status Upsample::BaseCompute(OpKernelContext* context, const std::vector Date: Fri, 23 Aug 2019 15:19:34 -0700 Subject: [PATCH 3/7] Fix build break --- onnxruntime/core/providers/cpu/tensor/upsample.cc | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/onnxruntime/core/providers/cpu/tensor/upsample.cc b/onnxruntime/core/providers/cpu/tensor/upsample.cc index 8234705ec00f9..e7072993b8972 100644 --- a/onnxruntime/core/providers/cpu/tensor/upsample.cc +++ b/onnxruntime/core/providers/cpu/tensor/upsample.cc @@ -197,10 +197,11 @@ Status upsampleLiner(const T* input, T* output, const TensorShape& input_shape, const TensorShape& output_shape, - const vector& scales) { + const vector& scales, + const bool is_resize) { if (!input || !output) return Status(ONNXRUNTIME, FAIL, is_resize ? "Resize: input / output value is nullptr" : - "Upsample: input / output value is nullptr", ); + "Upsample: input / output value is nullptr"); if (input_shape.NumDimensions() != output_shape.NumDimensions()) return Status(ONNXRUNTIME, FAIL, is_resize ? "Resize: input/output value's dimension mismatch" : "Upsample: input/output value's dimension mismatch"); From 39e9bd46df7e9f5a332a5f386ebd1102a3fb49f0 Mon Sep 17 00:00:00 2001 From: hariharans29 Date: Fri, 23 Aug 2019 15:29:23 -0700 Subject: [PATCH 4/7] Add test --- onnxruntime/test/providers/cpu/tensor/resize_op_test.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/onnxruntime/test/providers/cpu/tensor/resize_op_test.cc b/onnxruntime/test/providers/cpu/tensor/resize_op_test.cc index 3e790eacc8c4b..1d2f03e39ce1d 100644 --- a/onnxruntime/test/providers/cpu/tensor/resize_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/resize_op_test.cc @@ -46,6 +46,7 @@ TEST(ResizeOpTest, ResizeOpLineartDownSampleTest_2DBilinear) { test.AddOutput("Y", {(int64_t)(H * scales[0]), (int64_t)(W * scales[1])}, Y); test.Run(); } + TEST(ResizeOpTest, ResizeOpLineartUpSampleTest_4DBilinear) { OpTester test("Resize", 10); std::vector scales{1.0f, 1.0f, 2.0f, 4.0f}; From 126805c21ce073ae4a9577d04e66634f58bf41de Mon Sep 17 00:00:00 2001 From: hariharans29 Date: Fri, 23 Aug 2019 17:41:56 -0700 Subject: [PATCH 5/7] CUDA changes --- .../core/providers/cuda/tensor/resize_impl.cu | 71 ++++++++++++++++++- .../core/providers/cuda/tensor/upsample.cc | 28 ++++---- .../providers/cuda/tensor/upsample_impl.cu | 68 +++++++++++++++++- 3 files changed, 148 insertions(+), 19 deletions(-) diff --git a/onnxruntime/core/providers/cuda/tensor/resize_impl.cu b/onnxruntime/core/providers/cuda/tensor/resize_impl.cu index f8df8a9689f02..55d7fcaf01f49 100644 --- a/onnxruntime/core/providers/cuda/tensor/resize_impl.cu +++ b/onnxruntime/core/providers/cuda/tensor/resize_impl.cu @@ -29,8 +29,13 @@ __global__ void _ResizeNearestKernel(const size_t rank, output_data[id] = input_data[input_index]; } +// The following method supports a 4-D input in 'Linear mode' +// that amounts to 'Bilinear' Upsampling/Resizing in the sense that it assumes +// the scale values for the outermost 2 dimensions are 1. +// This is the common use-case where the 4-D input (batched multi-channel images) +// is usually of shape [N, C, H, W] and the scales are [1.0, 1.0, height_scale, width_scale] template -__global__ void _ResizeBilinearKernel(const int64_t input_dim2, +__global__ void _ResizeBilinear4DInputKernel(const int64_t input_dim2, const int64_t* input_pitches, const fast_divmod* output_div_pitches, const float* scales, @@ -90,6 +95,62 @@ __global__ void _ResizeBilinearKernel(const int64_t input_dim2, x11 * static_cast(y_offset_0 * x_offset_0); } +// The following method supports a 2-D input in 'Linear mode' +template +__global__ void _ResizeBilinear2DInputKernel(const int64_t input_dim0, + const int64_t* input_pitches, + const fast_divmod* output_div_pitches, + const float* scales, + const T* input_data, + T* output_data, + const size_t N) { + CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, N); + CUDA_LONG input_index = 0; + + int mod; + int index_of_dim0, index_of_dim1; + output_div_pitches[0].divmod(id, index_of_dim0, mod); + index_of_dim1 = mod; + int index_of_input_dim0, index_of_input_dim1; + float x_offset_0, y_offset_0, x_offset_1, y_offset_1; + index_of_input_dim0 = static_cast(index_of_dim0 / scales[0]); + index_of_input_dim1 = static_cast(index_of_dim1 / scales[1]); + input_index = index_of_input_dim0 * input_pitches[0] + index_of_input_dim1; + + T x00 = input_data[input_index]; + T x10, x01, x11; + + bool end_of_dim0 = false, end_of_dim1 = false; + if (index_of_input_dim0 == (input_dim0 - 1)) { + // It's the end in dimension 0 + x01 = x00; + end_of_dim0 = true; + } else { + x01 = input_data[input_index + input_pitches[0]]; + } + + if (index_of_input_dim1 == (input_pitches[0] - 1)) { + // It's the end in dimension 1 + x10 = x00; + x11 = x01; + end_of_dim1 = true; + } else { + x10 = input_data[input_index + 1]; + x11 = end_of_dim0 ? x10 : input_data[input_index + input_pitches[0] + 1]; + } + + y_offset_0 = end_of_dim0 ? 0.5f : index_of_dim0 / scales[0] - index_of_input_dim0; + y_offset_1 = 1.0f - y_offset_0; + x_offset_0 = end_of_dim1 ? 0.5f : index_of_dim1 / scales[1] - index_of_input_dim1; + x_offset_1 = 1.0f - x_offset_0; + + output_data[id] = + x00 * static_cast(y_offset_1 * x_offset_1) + + x01 * static_cast(y_offset_0 * x_offset_1) + + x10 * static_cast(y_offset_1 * x_offset_0) + + x11 * static_cast(y_offset_0 * x_offset_0); +} + template void ResizeImpl(const onnxruntime::UpsampleMode upsample_mode, const size_t rank, @@ -105,8 +166,12 @@ void ResizeImpl(const onnxruntime::UpsampleMode upsample_mode, _ResizeNearestKernel<<>>( rank, input_pitches, output_div_pitches, scales_vals, input_data, output_data, N); - } else if (onnxruntime::UpsampleMode::LINEAR == upsample_mode) { - _ResizeBilinearKernel<<>>( + } else if (onnxruntime::UpsampleMode::LINEAR == upsample_mode && rank == 4) { + _ResizeBilinear4DInputKernel<<>>( + input_dim2, input_pitches, output_div_pitches, scales_vals, + input_data, output_data, N); + } else if (onnxruntime::UpsampleMode::LINEAR == upsample_mode && rank == 2) { + _ResizeBilinear2DInputKernel<<>>( input_dim2, input_pitches, output_div_pitches, scales_vals, input_data, output_data, N); } diff --git a/onnxruntime/core/providers/cuda/tensor/upsample.cc b/onnxruntime/core/providers/cuda/tensor/upsample.cc index 88248983d70ae..3a9eb36c22f41 100644 --- a/onnxruntime/core/providers/cuda/tensor/upsample.cc +++ b/onnxruntime/core/providers/cuda/tensor/upsample.cc @@ -38,10 +38,21 @@ Status Upsample::BaseCompute(OpKernelContext* context, const std::vector& X_dims = X->Shape().GetDims(); auto rank = X_dims.size(); if (rank == 0) - return Status(ONNXRUNTIME, INVALID_ARGUMENT, "Upsample: input tensor cannot be scalar."); + return Status(ONNXRUNTIME, INVALID_ARGUMENT, + is_resize ? "Resize: input tensor cannot be scalar." : "Upsample: input tensor cannot be scalar."); if (rank != scales.size()) - return Status(ONNXRUNTIME, INVALID_ARGUMENT, "Upsample: input tensor's dimension does not match the scales."); + return Status(ONNXRUNTIME, INVALID_ARGUMENT, + is_resize ? "Resize: input tensor's dimension does not match the scales." : + "Upsample: input tensor's dimension does not match the scales."); + + if (UpsampleMode::LINEAR == mode_ && rank != 4 && rank != 2) { + std::ostringstream oss; + oss << "'Linear' mode only support 2-D inputs ('Bilinear') or 4-D inputs " + "with the corresponding outermost 2 scale values being 1 in the "; + oss << (is_resize ? "Resize operator" : "Upsample operator"); + return Status(ONNXRUNTIME, FAIL, oss.str()); + } std::vector Y_dims; for (std::size_t i = 0; i < rank; i++) { @@ -69,21 +80,12 @@ Status Upsample::BaseCompute(OpKernelContext* context, const std::vectorShape().Size(); - if (UpsampleMode::LINEAR == mode_) { - if (rank != 4) - if (is_resize) { - return Status(ONNXRUNTIME, FAIL, "Resize: linear mode only supports 4-D tensor with NCHW layout"); - } else { - return Status(ONNXRUNTIME, FAIL, "Upsample: linear mode only supports 4-D tensor with NCHW layout"); - } - } - if (is_resize) { CudaAsyncBuffer scales_vals(this, device_id, scales); scales_vals.CopyToGpu(); ResizeImpl(mode_, rank, - (UpsampleMode::LINEAR == mode_) ? X_dims[2] : 0, + (UpsampleMode::LINEAR == mode_) ? (rank == 2 ? X_dims[0] : X_dims[2]) : 0, input_strides.GpuPtr(), output_div_pitches.GpuPtr(), scales_vals.GpuPtr(), @@ -101,7 +103,7 @@ Status Upsample::BaseCompute(OpKernelContext* context, const std::vector -__global__ void _UpampleBilinearKernel(const int64_t input_dim2, +__global__ void _UpampleBilinear4DInputKernel(const int64_t input_dim2, const int64_t* input_pitches, const fast_divmod* output_div_pitches, const fast_divmod* scales_div, @@ -90,6 +95,59 @@ __global__ void _UpampleBilinearKernel(const int64_t input_dim2, output_data[id] = y0 + static_cast(x_offset_T * (y1 - y0) / scales_div3_T); } +// The following method supports a 2-D input in 'Linear mode' +template +__global__ void _UpampleBilinear2DInputKernel(const int64_t input_dim0, + const int64_t* input_pitches, + const fast_divmod* output_div_pitches, + const fast_divmod* scales_div, + const T* input_data, + T* output_data, + const size_t N) { + CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, N); + CUDA_LONG input_index = 0; + + int mod; + int index_of_dim0, index_of_dim1; + output_div_pitches[0].divmod(id, index_of_dim0, mod); + index_of_dim1 = mod; + int index_of_input_dim0, index_of_input_dim1, x_offset, y_offset; + scales_div[0].divmod(index_of_dim0, index_of_input_dim0, y_offset); + scales_div[1].divmod(index_of_dim1, index_of_input_dim1, x_offset); + + input_index = index_of_input_dim0 * input_pitches[0] + index_of_input_dim1; + + T x00 = input_data[input_index]; + T x10, x01, x11; + + bool end_of_dim0 = false; + if (index_of_input_dim0 == (input_dim0 - 1)) { + // It's the end in dimension 0 + x01 = x00; + end_of_dim0 = true; + } else { + x01 = input_data[input_index + input_pitches[0]]; + } + + if (index_of_input_dim1 == (input_pitches[0] - 1)) { + // It's the end in dimension 1 + x10 = x00; + x11 = x01; + } else { + x10 = input_data[input_index + 1]; + x11 = end_of_dim0 ? x10 : input_data[input_index + input_pitches[0] + 1]; + } + + T y_offset_T = static_cast(y_offset); + T x_offset_T = static_cast(x_offset); + T scales_div0_T = static_cast(scales_div[0].d_); + T scales_div1_T = static_cast(scales_div[1].d_); + T y0 = x00 + static_cast(y_offset_T * (x01 - x00) / scales_div0_T); + T y1 = x10 + static_cast(y_offset_T * (x11 - x10) / scales_div0_T); + + output_data[id] = y0 + static_cast(x_offset_T * (y1 - y0) / scales_div1_T); +} + template void UpampleImpl(const onnxruntime::UpsampleMode upsample_mode, const size_t rank, @@ -105,8 +163,12 @@ void UpampleImpl(const onnxruntime::UpsampleMode upsample_mode, _UpampleNearestKernel<<>>( rank, input_pitches, output_div_pitches, scales_div, input_data, output_data, N); - } else if (onnxruntime::UpsampleMode::LINEAR == upsample_mode) { - _UpampleBilinearKernel<<>>( + } else if (onnxruntime::UpsampleMode::LINEAR == upsample_mode && rank == 4) { + _UpampleBilinear4DInputKernel<<>>( + input_dim2, input_pitches, output_div_pitches, scales_div, + input_data, output_data, N); + } else if (onnxruntime::UpsampleMode::LINEAR == upsample_mode && rank == 2) { + _UpampleBilinear2DInputKernel<<>>( input_dim2, input_pitches, output_div_pitches, scales_div, input_data, output_data, N); } From 1bca648b729c9c17a3f24b467fcabeac409c66a0 Mon Sep 17 00:00:00 2001 From: Hariharan Seshadri Date: Wed, 28 Aug 2019 22:44:37 -0700 Subject: [PATCH 6/7] Resolve PR comments --- onnxruntime/core/providers/cpu/tensor/upsample.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/core/providers/cpu/tensor/upsample.cc b/onnxruntime/core/providers/cpu/tensor/upsample.cc index e7072993b8972..2bf7710105be8 100644 --- a/onnxruntime/core/providers/cpu/tensor/upsample.cc +++ b/onnxruntime/core/providers/cpu/tensor/upsample.cc @@ -63,7 +63,7 @@ Status UpsampleNearest(const T* input, const TensorShape& input_shape, const TensorShape& output_shape, const vector& scales, - const bool is_resize) { + bool is_resize) { if (!input || !output) return Status(ONNXRUNTIME, FAIL, is_resize ? "Resize: input/output value is nullptr" : "Upsample: input/output value is nullptr"); From 83862084c1595d5cf9b564c69e78815483b62ef7 Mon Sep 17 00:00:00 2001 From: Hariharan Seshadri Date: Wed, 28 Aug 2019 23:43:51 -0700 Subject: [PATCH 7/7] Resolve comments --- onnxruntime/core/providers/cpu/tensor/upsample.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/core/providers/cpu/tensor/upsample.cc b/onnxruntime/core/providers/cpu/tensor/upsample.cc index 2bf7710105be8..3dcfcb47a353b 100644 --- a/onnxruntime/core/providers/cpu/tensor/upsample.cc +++ b/onnxruntime/core/providers/cpu/tensor/upsample.cc @@ -198,7 +198,7 @@ Status upsampleLiner(const T* input, const TensorShape& input_shape, const TensorShape& output_shape, const vector& scales, - const bool is_resize) { + bool is_resize) { if (!input || !output) return Status(ONNXRUNTIME, FAIL, is_resize ? "Resize: input / output value is nullptr" : "Upsample: input / output value is nullptr");