From 44f7024ec5e17add62b0b621a25587bf3fe18409 Mon Sep 17 00:00:00 2001 From: tink2123 Date: Mon, 10 Jan 2022 09:52:07 +0000 Subject: [PATCH 1/4] support 5d for nearest --- paddle/fluid/operators/interpolate_v2_op.cc | 12 +-- paddle/fluid/operators/interpolate_v2_op.cu | 92 +++++++++++----- paddle/fluid/operators/interpolate_v2_op.h | 101 +++++++++++------- .../unittests/test_nearest_interp_v2_op.py | 9 +- python/paddle/nn/functional/common.py | 17 ++- python/paddle/nn/layer/common.py | 3 +- 6 files changed, 159 insertions(+), 75 deletions(-) diff --git a/paddle/fluid/operators/interpolate_v2_op.cc b/paddle/fluid/operators/interpolate_v2_op.cc index de276cfa31cb52..7783303785998e 100644 --- a/paddle/fluid/operators/interpolate_v2_op.cc +++ b/paddle/fluid/operators/interpolate_v2_op.cc @@ -249,12 +249,12 @@ static void Interpolate3DInferShapeCheck(framework::InferShapeContext* ctx) { auto dim_x = ctx->GetInputDim("X"); auto interp_method = ctx->Attrs().Get("interp_method"); - PADDLE_ENFORCE_EQ( - "trilinear", interp_method, - platform::errors::InvalidArgument( - "Interpolation method can only be \"trilinear\" when Input(X) " - "dimension is 5, but got method = %s .", - interp_method)); + PADDLE_ENFORCE("nearest" == interp_method || "trilinear" == interp_method, + platform::errors::InvalidArgument( + "Interpolation method can only be \"trilinear\" or " + "\"nearest\" when Input(X) " + "dimension is 5, but got method = %s .", + interp_method)); const DataLayout data_layout = framework::StringToDataLayout( ctx->Attrs().Get("data_layout")); diff --git a/paddle/fluid/operators/interpolate_v2_op.cu b/paddle/fluid/operators/interpolate_v2_op.cu index bc1ab704aafe3a..0a0839f6bc017c 100644 --- a/paddle/fluid/operators/interpolate_v2_op.cu +++ b/paddle/fluid/operators/interpolate_v2_op.cu @@ -25,31 +25,40 @@ using DataLayout = framework::DataLayout; template __global__ void KeNearestNeighborInterpFw( - const T* in, const size_t in_img_h, const size_t in_img_w, - const size_t input_h, const size_t input_w, T* out, const size_t out_img_h, - const size_t out_img_w, const size_t output_h, const size_t output_w, - const size_t num_channels, const float ratio_h, const float ratio_w, + const T* in, const size_t in_img_d, const size_t in_img_h, + const size_t in_img_w, const size_t input_h, const size_t input_w, T* out, + const size_t out_img_d, const size_t out_img_h, const size_t out_img_w, + const size_t output_h, const size_t output_w, const size_t num_channels, + const float ratio_d, const float ratio_h, const float ratio_w, const bool align_corners, const DataLayout data_layout) { - int nthreads = output_h * output_w; + int nthreads = output_h * output_w; // ncdhw int tid = blockIdx.x * blockDim.x + threadIdx.x; int stride = blockDim.x * gridDim.x; for (; tid < nthreads; tid += stride) { - int out_id_h = tid / output_w; - int out_id_w = tid % output_w; + int out_id_h = tid / output_w; // n + int out_id_w = tid % output_w; // cdhw + int in_img_size = input_w / num_channels; int out_img_size = output_w / num_channels; - int channel_id, out_img_idy, out_img_idx; + int channel_id, out_img_idt, out_img_idy, out_img_idx; if (data_layout == DataLayout::kNCHW) { channel_id = out_id_w / out_img_size; - out_img_idy = (out_id_w % out_img_size) / out_img_w; + out_img_idt = (out_id_w % out_img_size) / out_img_h / out_img_w; + out_img_idy = ((out_id_w % out_img_size) / out_img_w) % out_img_h; out_img_idx = tid % out_img_w; } else { - out_img_idy = out_id_w / (out_img_w * num_channels); + out_img_idt = out_id_w / (out_img_h * out_img_w * num_channels); + out_img_idy = out_id_w % (out_img_h * out_img_w * num_channels) / + (out_img_w * num_channels); out_img_idx = out_id_w % (out_img_w * num_channels) / num_channels; channel_id = tid % num_channels; } + int in_img_idt = (align_corners) + ? static_cast(ratio_d * out_img_idt + 0.5) + : static_cast(ratio_d * out_img_idt); + int in_img_idy = (align_corners) ? static_cast(ratio_h * out_img_idy + 0.5) : static_cast(ratio_h * out_img_idy); @@ -59,9 +68,12 @@ __global__ void KeNearestNeighborInterpFw( if (data_layout == DataLayout::kNCHW) { out[tid] = in[out_id_h * input_w + channel_id * in_img_size + - in_img_idy * in_img_w + in_img_idx]; + in_img_idt * in_img_h * in_img_w + in_img_idy * in_img_w + + in_img_idx]; } else { - out[tid] = in[out_id_h * input_w + in_img_idy * in_img_w * num_channels + + out[tid] = in[out_id_h * input_w + + in_img_idt * in_img_h * in_img_w * num_channels + + in_img_idy * in_img_w * num_channels + in_img_idx * num_channels + channel_id]; } } @@ -69,10 +81,11 @@ __global__ void KeNearestNeighborInterpFw( template __global__ void KeNearestNeighborInterpBw( - T* in, const size_t in_img_h, const size_t in_img_w, const size_t input_h, - const size_t input_w, const T* out, const size_t out_img_h, - const size_t out_img_w, const size_t output_h, const size_t output_w, - const size_t num_channels, const float ratio_h, const float ratio_w, + T* in, const size_t in_img_d, const size_t in_img_h, const size_t in_img_w, + const size_t input_h, const size_t input_w, const T* out, + const size_t out_img_d, const size_t out_img_h, const size_t out_img_w, + const size_t output_h, const size_t output_w, const size_t num_channels, + const float ratio_d, const float ratio_h, const float ratio_w, const bool align_corners, const DataLayout data_layout) { int nthreads = output_h * output_w; int tid = blockIdx.x * blockDim.x + threadIdx.x; @@ -83,17 +96,23 @@ __global__ void KeNearestNeighborInterpBw( int in_img_size = input_w / num_channels; int out_img_size = output_w / num_channels; - int channel_id, out_img_idy, out_img_idx; + int channel_id, out_img_idt, out_img_idy, out_img_idx; if (data_layout == DataLayout::kNCHW) { channel_id = out_id_w / out_img_size; - out_img_idy = (out_id_w % out_img_size) / out_img_w; + out_img_idt = (out_id_w % out_img_size) / out_img_h / out_img_w; + out_img_idy = ((out_id_w % out_img_size) / out_img_w) % out_img_h; out_img_idx = tid % out_img_w; } else { - out_img_idy = out_id_w / (out_img_w * num_channels); + out_img_idt = out_id_w / (out_img_h * out_img_w * num_channels); + out_img_idy = out_id_w % (out_img_h * out_img_w * num_channels) / + (out_img_w * num_channels); out_img_idx = out_id_w % (out_img_w * num_channels) / num_channels; channel_id = tid % num_channels; } + int in_img_idt = (align_corners) + ? static_cast(ratio_d * out_img_idt + 0.5) + : static_cast(ratio_d * out_img_idt); int in_img_idy = (align_corners) ? static_cast(ratio_h * out_img_idy + 0.5) : static_cast(ratio_h * out_img_idy); @@ -104,9 +123,12 @@ __global__ void KeNearestNeighborInterpBw( T* in_pos; if (data_layout == DataLayout::kNCHW) { in_pos = &in[out_id_h * input_w + channel_id * in_img_size + - in_img_idy * in_img_w + in_img_idx]; + in_img_idt * in_img_h * in_img_w + in_img_idy * in_img_w + + in_img_idx]; } else { - in_pos = &in[out_id_h * input_w + in_img_idy * in_img_w * num_channels + + in_pos = &in[out_id_h * input_w + + in_img_idt * in_img_h * in_img_w * num_channels + + in_img_idy * in_img_w * num_channels + in_img_idx * num_channels + channel_id]; } const T out_pos = out[out_id_h * output_w + out_id_w]; @@ -1180,11 +1202,14 @@ static void Interpolate2DCUDAFwd(const framework::ExecutionContext& ctx, platform::GetGpuLaunchConfig1D(ctx.cuda_device_context(), pixelNum); if ("nearest" == interp_method) { + float ratio_d = 1.f; + int out_d = static_cast(1); KeNearestNeighborInterpFw< T><<>>( - input_data, in_h, in_w, n, in_chw, output_data, out_h, out_w, n, - out_chw, c, ratio_h, ratio_w, align_corners, data_layout); + input_data, in_d, in_h, in_w, n, in_chw, output_data, out_d, out_h, + out_w, n, out_chw, c, ratio_d, ratio_h, ratio_w, align_corners, + data_layout); } else if ("bilinear" == interp_method) { dim3 thread_num = config.thread_per_block; #ifdef WITH_NV_JETSON @@ -1376,6 +1401,13 @@ static void Interpolate3DCUDAFwd(const framework::ExecutionContext& ctx, input_data, in_d, in_h, in_w, n, in_cdhw, output_data, out_d, out_h, out_w, n, out_cdhw, c, ratio_d, ratio_h, ratio_w, align_corners, align_mode, data_layout); + } else if ("nearest" == interp_method) { + KeNearestNeighborInterpFw< + T><<>>( + input_data, in_d, in_h, in_w, n, in_cdhw, output_data, out_d, out_h, + out_w, n, out_cdhw, c, ratio_d, ratio_h, ratio_w, align_corners, + data_layout); } } @@ -1602,11 +1634,14 @@ static void Interpolate2DCUDABwd(const framework::ExecutionContext& ctx, platform::GetGpuLaunchConfig1D(ctx.cuda_device_context(), pixelNum); if ("nearest" == interp_method) { + float ratio_d = 1.f; + int out_d = static_cast(1); KeNearestNeighborInterpBw< T><<>>( - input_grad_data, in_h, in_w, n, in_chw, output_grad_data, out_h, out_w, - n, out_chw, c, ratio_h, ratio_w, align_corners, data_layout); + input_grad_data, in_d, in_h, in_w, n, in_chw, output_grad_data, out_d, + out_h, out_w, n, out_chw, c, ratio_d, ratio_h, ratio_w, align_corners, + data_layout); } else if ("bilinear" == interp_method) { const T align_type_value = (align_mode == 0 && !align_corners) ? 0.5f : 0; bool is_nchw = (data_layout == DataLayout::kNCHW) ? true : false; @@ -1801,6 +1836,13 @@ static void Interpolate3DCUDABwd(const framework::ExecutionContext& ctx, input_grad_data, in_d, in_h, in_w, n, in_cdhw, output_grad_data, out_d, out_h, out_w, n, out_cdhw, c, ratio_d, ratio_h, ratio_w, align_corners, align_mode, data_layout); + } else if ("nearest" == interp_method) { + KeNearestNeighborInterpBw< + T><<>>( + input_grad_data, in_d, in_h, in_w, n, in_cdhw, output_grad_data, out_d, + out_h, out_w, n, out_cdhw, c, ratio_d, ratio_h, ratio_w, align_corners, + data_layout); } } diff --git a/paddle/fluid/operators/interpolate_v2_op.h b/paddle/fluid/operators/interpolate_v2_op.h index 8daf440f60e5f6..ebc45114d9eb34 100644 --- a/paddle/fluid/operators/interpolate_v2_op.h +++ b/paddle/fluid/operators/interpolate_v2_op.h @@ -93,27 +93,32 @@ inline void ExtractNCDWH(const framework::DDim& dims, template static void NearestNeighborInterpolate(const Tensor& input, Tensor* output, - const float ratio_h, const float ratio_w, - const int n, const int c, + const float ratio_d, const float ratio_h, + const float ratio_w, const int n, + const int c, const int out_d, const int out_h, const int out_w, const bool align_corners, const DataLayout& data_layout) { - auto input_t = EigenTensor::From(input); - auto output_t = EigenTensor::From(*output); - for (int k = 0; k < out_h; k++) { // loop for images - int in_k = (align_corners) ? static_cast(ratio_h * k + 0.5) - : static_cast(ratio_h * k); + auto input_t = EigenTensor::From(input); + auto output_t = EigenTensor::From(*output); + for (int d = 0; d < out_d; d++) { // loop for images + int in_d = (align_corners) ? static_cast(ratio_d * d + 0.5) + : static_cast(ratio_d * d); + for (int k = 0; k < out_h; k++) { + int in_k = (align_corners) ? static_cast(ratio_h * k + 0.5) + : static_cast(ratio_h * k); - for (int l = 0; l < out_w; l++) { - int in_l = (align_corners) ? static_cast(ratio_w * l + 0.5) - : static_cast(ratio_w * l); + for (int l = 0; l < out_w; l++) { + int in_l = (align_corners) ? static_cast(ratio_w * l + 0.5) + : static_cast(ratio_w * l); - for (int i = 0; i < n; i++) { // loop for batches - for (int j = 0; j < c; j++) { // loop for channels - if (data_layout == DataLayout::kNCHW) { - output_t(i, j, k, l) = input_t(i, j, in_k, in_l); - } else { - output_t(i, k, l, j) = input_t(i, in_k, in_l, j); + for (int i = 0; i < n; i++) { // loop for batches + for (int j = 0; j < c; j++) { // loop for channels + if (data_layout == DataLayout::kNCHW) { + output_t(i, j, d, k, l) = input_t(i, j, in_d, in_k, in_l); + } else { // NDHWC + output_t(i, d, k, l, j) = input_t(i, in_d, in_k, in_l, j); + } } } } @@ -557,26 +562,33 @@ static void BicubicInterpolation(const Tensor& input, Tensor* output, template static void NearestNeighborInterpolateGrad( - const Tensor& output_grad, Tensor* input_grad, const float ratio_h, - const float ratio_w, const int n, const int c, const int out_h, - const int out_w, const bool align_corners, const DataLayout data_layout) { - auto input_grad_t = EigenTensor::From(*input_grad); - auto output_grad_t = EigenTensor::From(output_grad); + const Tensor& output_grad, Tensor* input_grad, const float ratio_d, + const float ratio_h, const float ratio_w, const int n, const int c, + const int out_d, const int out_h, const int out_w, const bool align_corners, + const DataLayout data_layout) { + auto input_grad_t = EigenTensor::From(*input_grad); + auto output_grad_t = EigenTensor::From(output_grad); - for (int k = 0; k < out_h; k++) { // loop for images - int in_k = (align_corners) ? static_cast(ratio_h * k + 0.5) - : static_cast(ratio_h * k); + for (int d = 0; d < out_d; d++) { + int in_d = (align_corners) ? static_cast(ratio_d * d + 0.5) + : static_cast(ratio_d * d); + for (int k = 0; k < out_h; k++) { // loop for images + int in_k = (align_corners) ? static_cast(ratio_h * k + 0.5) + : static_cast(ratio_h * k); - for (int l = 0; l < out_w; l++) { - int in_l = (align_corners) ? static_cast(ratio_w * l + 0.5) - : static_cast(ratio_w * l); + for (int l = 0; l < out_w; l++) { + int in_l = (align_corners) ? static_cast(ratio_w * l + 0.5) + : static_cast(ratio_w * l); - for (int i = 0; i < n; i++) { // loop for batches - for (int j = 0; j < c; j++) { // loop for channels - if (data_layout == DataLayout::kNCHW) { - input_grad_t(i, j, in_k, in_l) += output_grad_t(i, j, k, l); - } else { - input_grad_t(i, in_k, in_l, j) += output_grad_t(i, k, l, j); + for (int i = 0; i < n; i++) { // loop for batches + for (int j = 0; j < c; j++) { // loop for channels + if (data_layout == DataLayout::kNCHW) { + input_grad_t(i, j, in_d, in_k, in_l) += + output_grad_t(i, j, d, k, l); + } else { + input_grad_t(i, in_d, in_k, in_l, j) += + output_grad_t(i, d, k, l, j); + } } } } @@ -978,8 +990,11 @@ static void Interpolate2DCPUFwd(const framework::ExecutionContext& ctx, out_h, out_w, align_corners, align_mode, data_layout); } else if ("nearest" == interp_method) { - NearestNeighborInterpolate(input, output, ratio_h, ratio_w, n, c, out_h, - out_w, align_corners, data_layout); + float ratio_d = 1.f; + int out_d = static_cast(1); + NearestNeighborInterpolate(input, output, ratio_d, ratio_h, ratio_w, n, + c, out_d, out_h, out_w, align_corners, + data_layout); } else if ("bicubic" == interp_method) { BicubicInterpolation(input, output, ratio_h, ratio_w, in_h, in_w, n, c, out_h, out_w, align_corners, data_layout); @@ -1137,6 +1152,10 @@ static void Interpolate3DCPUFwd(const framework::ExecutionContext& ctx, TrilinearInterpolation(input, output, ratio_d, ratio_h, ratio_w, in_d, in_h, in_w, n, c, out_d, out_h, out_w, align_corners, align_mode, data_layout); + } else if ("nearest" == interp_method) { + NearestNeighborInterpolate(input, output, ratio_d, ratio_h, ratio_w, n, + c, out_d, out_h, out_w, align_corners, + data_layout); } } @@ -1338,9 +1357,11 @@ static void Interpolate2DCPUBwd(const framework::ExecutionContext& ctx, in_h, in_w, n, c, out_h, out_w, align_corners, align_mode, data_layout); } else if ("nearest" == interp_method) { - NearestNeighborInterpolateGrad(output_grad, input_grad, ratio_h, ratio_w, - n, c, out_h, out_w, align_corners, - data_layout); + float ratio_d = 1.f; + int out_d = static_cast(1); + NearestNeighborInterpolateGrad(output_grad, input_grad, ratio_d, ratio_h, + ratio_w, n, c, out_d, out_h, out_w, + align_corners, data_layout); } else if ("bicubic" == interp_method) { BicubicInterpolationGrad(output_grad, input_grad, ratio_h, ratio_w, in_h, in_w, n, c, out_h, out_w, align_corners, @@ -1489,6 +1510,10 @@ static void Interpolate3DCPUBwd(const framework::ExecutionContext& ctx, TrilinearInterpolationGrad( output_grad, input_grad, ratio_d, ratio_h, ratio_w, in_d, in_h, in_w, n, c, out_d, out_h, out_w, align_corners, align_mode, data_layout); + } else if ("nearest" == interp_method) { + NearestNeighborInterpolateGrad(output_grad, input_grad, ratio_d, ratio_h, + ratio_w, n, c, out_d, out_h, out_w, + align_corners, data_layout); } } diff --git a/python/paddle/fluid/tests/unittests/test_nearest_interp_v2_op.py b/python/paddle/fluid/tests/unittests/test_nearest_interp_v2_op.py index 04962a93c11c1e..485558d1b7e281 100755 --- a/python/paddle/fluid/tests/unittests/test_nearest_interp_v2_op.py +++ b/python/paddle/fluid/tests/unittests/test_nearest_interp_v2_op.py @@ -23,6 +23,8 @@ import paddle from paddle.nn.functional import interpolate +paddle.enable_static() + def nearest_neighbor_interp_np(X, out_h, @@ -78,7 +80,7 @@ def nearest_neighbor_interp_np(X, if data_layout == "NHWC": out = np.transpose(out, (0, 2, 3, 1)) # NCHW => NHWC - + out = np.expand_dims(out, 2) return out.astype(X.dtype) @@ -117,6 +119,7 @@ def setUp(self): output_np = nearest_neighbor_interp_np( input_np, out_h, out_w, scale_h, scale_w, self.out_size, self.actual_shape, self.align_corners, self.data_layout) + print("input shape:", input_np.shape) self.inputs = {'X': input_np} if self.out_size is not None: self.inputs['OutSize'] = self.out_size @@ -137,6 +140,8 @@ def setUp(self): self.scale = [self.scale[0], self.scale[0]] self.attrs['scale'] = self.scale self.outputs = {'Out': output_np} + print("=========================") + print(output_np.shape) def test_check_output(self): self.check_output() @@ -164,6 +169,7 @@ def init_test_case(self): self.align_corners = True +""" class TestNearestNeighborInterpCase2(TestNearestInterpOp): def init_test_case(self): self.interp_method = 'nearest' @@ -568,6 +574,7 @@ def attr_scale_value(): self.assertRaises(TypeError, attr_scale_type) self.assertRaises(ValueError, attr_scale_value) +""" if __name__ == "__main__": unittest.main() diff --git a/python/paddle/nn/functional/common.py b/python/paddle/nn/functional/common.py index 3dba9505e92c79..655550c594df22 100644 --- a/python/paddle/nn/functional/common.py +++ b/python/paddle/nn/functional/common.py @@ -221,7 +221,8 @@ def interpolate(x, ValueError: The 'mode' of image_resize can only be 'linear', 'bilinear', 'trilinear', 'bicubic', 'area' or 'nearest' currently. ValueError: 'linear' only support 3-D tensor. - ValueError: 'bilinear', 'bicubic' and 'nearest' only support 4-D tensor. + ValueError: 'bilinear' and 'bicubic' only support 4-D tensor. + ValueError: 'nearest' only support 4-D or 5-D tensor. ValueError: 'trilinear' only support 5-D tensor. ValueError: One of size and scale_factor must not be None. ValueError: size length should be 1 for input 3-D tensor. @@ -276,9 +277,11 @@ def interpolate(x, if resample in ['LINEAR'] and len(x.shape) != 3: raise ValueError("'linear' only support 3-D tensor.") - if resample in ['BILINEAR', 'NEAREST', 'BICUBIC'] and len(x.shape) != 4: - raise ValueError( - "'bilinear', 'bicubic' and 'nearest' only support 4-D tensor.") + if resample in ['NEAREST'] and len(x.shape) != 4 and len(x.shape) != 5: + raise ValueError("'NEAREST' only support 4-D or 5-D tensor.") + + if resample in ['BILINEAR', 'BICUBIC'] and len(x.shape) != 4: + raise ValueError("'bilinear' and 'bicubic' only support 4-D tensor.") if resample == 'TRILINEAR' and len(x.shape) != 5: raise ValueError("'trilinear'only support 5-D tensor.") @@ -332,6 +335,8 @@ def _is_list_or_turple_(data): if resample == 'NEAREST': align_corners = False + if resample in ['NEAREST'] and len(x.shape) == 4: + x = unsqueeze(x, axis=[2]) inputs = {"X": x} attrs = { @@ -466,6 +471,8 @@ def _is_list_or_turple_(data): out = _C_ops.trilinear_interp_v2(x, *dy_attr) elif resample_type == "nearest": out = _C_ops.nearest_interp_v2(x, *dy_attr) + if len(x.shape) == 4: + return squeeze(out, [2]) elif resample_type == "bicubic": out = _C_ops.bicubic_interp_v2(x, *dy_attr) return out @@ -475,6 +482,8 @@ def _is_list_or_turple_(data): inputs=inputs, outputs={"Out": out}, attrs=attrs) + if resample_type == "nearest" and len(x.shape) == 4: + return squeeze(out, [2]) return out diff --git a/python/paddle/nn/layer/common.py b/python/paddle/nn/layer/common.py index 22f7f798374d8a..89ff156bded2af 100644 --- a/python/paddle/nn/layer/common.py +++ b/python/paddle/nn/layer/common.py @@ -359,8 +359,9 @@ class Upsample(Layer): ValueError: The 'mode' of image_resize can only be 'linear', 'bilinear', 'trilinear', 'bicubic', or 'nearest' currently. ValueError: 'linear' only support 3-D tensor. - ValueError: 'bilinear', 'bicubic' and 'nearest' only support 4-D tensor. + ValueError: 'bilinear' and 'bicubic' only support 4-D tensor. ValueError: 'trilinear' only support 5-D tensor. + ValueError: 'nearest' only support 4-D or 5-D tensor. ValueError: One of size and scale_factor must not be None. ValueError: size length should be 1 for input 3-D tensor. ValueError: size length should be 2 for input 4-D tensor. From 97d83575e2c5aa893681e8f3fec2acf840ccc0a2 Mon Sep 17 00:00:00 2001 From: tink2123 Date: Tue, 11 Jan 2022 05:18:17 +0000 Subject: [PATCH 2/4] update nearest3d unittest, test=develop --- paddle/fluid/operators/interpolate_v2_op.cu | 114 +++++++++-- paddle/fluid/operators/interpolate_v2_op.h | 90 +++++++-- .../unittests/test_nearest_interp_v2_op.py | 185 +++++++++++++++--- python/paddle/nn/functional/common.py | 6 - 4 files changed, 333 insertions(+), 62 deletions(-) diff --git a/paddle/fluid/operators/interpolate_v2_op.cu b/paddle/fluid/operators/interpolate_v2_op.cu index 0a0839f6bc017c..3db0fdf5e6da4e 100644 --- a/paddle/fluid/operators/interpolate_v2_op.cu +++ b/paddle/fluid/operators/interpolate_v2_op.cu @@ -25,6 +25,50 @@ using DataLayout = framework::DataLayout; template __global__ void KeNearestNeighborInterpFw( + const T* in, const size_t in_img_h, const size_t in_img_w, + const size_t input_h, const size_t input_w, T* out, const size_t out_img_h, + const size_t out_img_w, const size_t output_h, const size_t output_w, + const size_t num_channels, const float ratio_h, const float ratio_w, + const bool align_corners, const DataLayout data_layout) { + int nthreads = output_h * output_w; + int tid = blockIdx.x * blockDim.x + threadIdx.x; + int stride = blockDim.x * gridDim.x; + for (; tid < nthreads; tid += stride) { + int out_id_h = tid / output_w; + int out_id_w = tid % output_w; + int in_img_size = input_w / num_channels; + int out_img_size = output_w / num_channels; + + int channel_id, out_img_idy, out_img_idx; + if (data_layout == DataLayout::kNCHW) { + channel_id = out_id_w / out_img_size; + out_img_idy = (out_id_w % out_img_size) / out_img_w; + out_img_idx = tid % out_img_w; + } else { + out_img_idy = out_id_w / (out_img_w * num_channels); + out_img_idx = out_id_w % (out_img_w * num_channels) / num_channels; + channel_id = tid % num_channels; + } + + int in_img_idy = (align_corners) + ? static_cast(ratio_h * out_img_idy + 0.5) + : static_cast(ratio_h * out_img_idy); + int in_img_idx = (align_corners) + ? static_cast(ratio_w * out_img_idx + 0.5) + : static_cast(ratio_w * out_img_idx); + + if (data_layout == DataLayout::kNCHW) { + out[tid] = in[out_id_h * input_w + channel_id * in_img_size + + in_img_idy * in_img_w + in_img_idx]; + } else { + out[tid] = in[out_id_h * input_w + in_img_idy * in_img_w * num_channels + + in_img_idx * num_channels + channel_id]; + } + } +} + +template +__global__ void KeNearestNeighbor3DInterpFw( const T* in, const size_t in_img_d, const size_t in_img_h, const size_t in_img_w, const size_t input_h, const size_t input_w, T* out, const size_t out_img_d, const size_t out_img_h, const size_t out_img_w, @@ -35,9 +79,8 @@ __global__ void KeNearestNeighborInterpFw( int tid = blockIdx.x * blockDim.x + threadIdx.x; int stride = blockDim.x * gridDim.x; for (; tid < nthreads; tid += stride) { - int out_id_h = tid / output_w; // n - int out_id_w = tid % output_w; // cdhw - + int out_id_h = tid / output_w; + int out_id_w = tid % output_w; int in_img_size = input_w / num_channels; int out_img_size = output_w / num_channels; @@ -81,6 +124,53 @@ __global__ void KeNearestNeighborInterpFw( template __global__ void KeNearestNeighborInterpBw( + T* in, const size_t in_img_h, const size_t in_img_w, const size_t input_h, + const size_t input_w, const T* out, const size_t out_img_h, + const size_t out_img_w, const size_t output_h, const size_t output_w, + const size_t num_channels, const float ratio_h, const float ratio_w, + const bool align_corners, const DataLayout data_layout) { + int nthreads = output_h * output_w; + int tid = blockIdx.x * blockDim.x + threadIdx.x; + int stride = blockDim.x * gridDim.x; + for (; tid < nthreads; tid += stride) { + int out_id_h = tid / output_w; + int out_id_w = tid % output_w; + int in_img_size = input_w / num_channels; + int out_img_size = output_w / num_channels; + + int channel_id, out_img_idy, out_img_idx; + if (data_layout == DataLayout::kNCHW) { + channel_id = out_id_w / out_img_size; + out_img_idy = (out_id_w % out_img_size) / out_img_w; + out_img_idx = tid % out_img_w; + } else { + out_img_idy = out_id_w / (out_img_w * num_channels); + out_img_idx = out_id_w % (out_img_w * num_channels) / num_channels; + channel_id = tid % num_channels; + } + + int in_img_idy = (align_corners) + ? static_cast(ratio_h * out_img_idy + 0.5) + : static_cast(ratio_h * out_img_idy); + int in_img_idx = (align_corners) + ? static_cast(ratio_w * out_img_idx + 0.5) + : static_cast(ratio_w * out_img_idx); + + T* in_pos; + if (data_layout == DataLayout::kNCHW) { + in_pos = &in[out_id_h * input_w + channel_id * in_img_size + + in_img_idy * in_img_w + in_img_idx]; + } else { + in_pos = &in[out_id_h * input_w + in_img_idy * in_img_w * num_channels + + in_img_idx * num_channels + channel_id]; + } + const T out_pos = out[out_id_h * output_w + out_id_w]; + platform::CudaAtomicAdd(in_pos, out_pos); + } +} + +template +__global__ void KeNearestNeighbor3DInterpBw( T* in, const size_t in_img_d, const size_t in_img_h, const size_t in_img_w, const size_t input_h, const size_t input_w, const T* out, const size_t out_img_d, const size_t out_img_h, const size_t out_img_w, @@ -1202,14 +1292,11 @@ static void Interpolate2DCUDAFwd(const framework::ExecutionContext& ctx, platform::GetGpuLaunchConfig1D(ctx.cuda_device_context(), pixelNum); if ("nearest" == interp_method) { - float ratio_d = 1.f; - int out_d = static_cast(1); KeNearestNeighborInterpFw< T><<>>( - input_data, in_d, in_h, in_w, n, in_chw, output_data, out_d, out_h, - out_w, n, out_chw, c, ratio_d, ratio_h, ratio_w, align_corners, - data_layout); + input_data, in_h, in_w, n, in_chw, output_data, out_h, out_w, n, + out_chw, c, ratio_h, ratio_w, align_corners, data_layout); } else if ("bilinear" == interp_method) { dim3 thread_num = config.thread_per_block; #ifdef WITH_NV_JETSON @@ -1402,7 +1489,7 @@ static void Interpolate3DCUDAFwd(const framework::ExecutionContext& ctx, out_w, n, out_cdhw, c, ratio_d, ratio_h, ratio_w, align_corners, align_mode, data_layout); } else if ("nearest" == interp_method) { - KeNearestNeighborInterpFw< + KeNearestNeighbor3DInterpFw< T><<>>( input_data, in_d, in_h, in_w, n, in_cdhw, output_data, out_d, out_h, @@ -1634,14 +1721,11 @@ static void Interpolate2DCUDABwd(const framework::ExecutionContext& ctx, platform::GetGpuLaunchConfig1D(ctx.cuda_device_context(), pixelNum); if ("nearest" == interp_method) { - float ratio_d = 1.f; - int out_d = static_cast(1); KeNearestNeighborInterpBw< T><<>>( - input_grad_data, in_d, in_h, in_w, n, in_chw, output_grad_data, out_d, - out_h, out_w, n, out_chw, c, ratio_d, ratio_h, ratio_w, align_corners, - data_layout); + input_grad_data, in_h, in_w, n, in_chw, output_grad_data, out_h, out_w, + n, out_chw, c, ratio_h, ratio_w, align_corners, data_layout); } else if ("bilinear" == interp_method) { const T align_type_value = (align_mode == 0 && !align_corners) ? 0.5f : 0; bool is_nchw = (data_layout == DataLayout::kNCHW) ? true : false; @@ -1837,7 +1921,7 @@ static void Interpolate3DCUDABwd(const framework::ExecutionContext& ctx, out_h, out_w, n, out_cdhw, c, ratio_d, ratio_h, ratio_w, align_corners, align_mode, data_layout); } else if ("nearest" == interp_method) { - KeNearestNeighborInterpBw< + KeNearestNeighbor3DInterpBw< T><<>>( input_grad_data, in_d, in_h, in_w, n, in_cdhw, output_grad_data, out_d, diff --git a/paddle/fluid/operators/interpolate_v2_op.h b/paddle/fluid/operators/interpolate_v2_op.h index ebc45114d9eb34..0af799eca0c55c 100644 --- a/paddle/fluid/operators/interpolate_v2_op.h +++ b/paddle/fluid/operators/interpolate_v2_op.h @@ -93,12 +93,40 @@ inline void ExtractNCDWH(const framework::DDim& dims, template static void NearestNeighborInterpolate(const Tensor& input, Tensor* output, - const float ratio_d, const float ratio_h, - const float ratio_w, const int n, - const int c, const int out_d, + const float ratio_h, const float ratio_w, + const int n, const int c, const int out_h, const int out_w, const bool align_corners, const DataLayout& data_layout) { + auto input_t = EigenTensor::From(input); + auto output_t = EigenTensor::From(*output); + for (int k = 0; k < out_h; k++) { // loop for images + int in_k = (align_corners) ? static_cast(ratio_h * k + 0.5) + : static_cast(ratio_h * k); + + for (int l = 0; l < out_w; l++) { + int in_l = (align_corners) ? static_cast(ratio_w * l + 0.5) + : static_cast(ratio_w * l); + + for (int i = 0; i < n; i++) { // loop for batches + for (int j = 0; j < c; j++) { // loop for channels + if (data_layout == DataLayout::kNCHW) { + output_t(i, j, k, l) = input_t(i, j, in_k, in_l); + } else { + output_t(i, k, l, j) = input_t(i, in_k, in_l, j); + } + } + } + } + } +} + +template +static void NearestNeighbor3DInterpolate( + const Tensor& input, Tensor* output, const float ratio_d, + const float ratio_h, const float ratio_w, const int n, const int c, + const int out_d, const int out_h, const int out_w, const bool align_corners, + const DataLayout& data_layout) { auto input_t = EigenTensor::From(input); auto output_t = EigenTensor::From(*output); for (int d = 0; d < out_d; d++) { // loop for images @@ -562,6 +590,35 @@ static void BicubicInterpolation(const Tensor& input, Tensor* output, template static void NearestNeighborInterpolateGrad( + const Tensor& output_grad, Tensor* input_grad, const float ratio_h, + const float ratio_w, const int n, const int c, const int out_h, + const int out_w, const bool align_corners, const DataLayout data_layout) { + auto input_grad_t = EigenTensor::From(*input_grad); + auto output_grad_t = EigenTensor::From(output_grad); + + for (int k = 0; k < out_h; k++) { // loop for images + int in_k = (align_corners) ? static_cast(ratio_h * k + 0.5) + : static_cast(ratio_h * k); + + for (int l = 0; l < out_w; l++) { + int in_l = (align_corners) ? static_cast(ratio_w * l + 0.5) + : static_cast(ratio_w * l); + + for (int i = 0; i < n; i++) { // loop for batches + for (int j = 0; j < c; j++) { // loop for channels + if (data_layout == DataLayout::kNCHW) { + input_grad_t(i, j, in_k, in_l) += output_grad_t(i, j, k, l); + } else { + input_grad_t(i, in_k, in_l, j) += output_grad_t(i, k, l, j); + } + } + } + } + } +} + +template +static void NearestNeighbor3DInterpolateGrad( const Tensor& output_grad, Tensor* input_grad, const float ratio_d, const float ratio_h, const float ratio_w, const int n, const int c, const int out_d, const int out_h, const int out_w, const bool align_corners, @@ -990,11 +1047,8 @@ static void Interpolate2DCPUFwd(const framework::ExecutionContext& ctx, out_h, out_w, align_corners, align_mode, data_layout); } else if ("nearest" == interp_method) { - float ratio_d = 1.f; - int out_d = static_cast(1); - NearestNeighborInterpolate(input, output, ratio_d, ratio_h, ratio_w, n, - c, out_d, out_h, out_w, align_corners, - data_layout); + NearestNeighborInterpolate(input, output, ratio_h, ratio_w, n, c, out_h, + out_w, align_corners, data_layout); } else if ("bicubic" == interp_method) { BicubicInterpolation(input, output, ratio_h, ratio_w, in_h, in_w, n, c, out_h, out_w, align_corners, data_layout); @@ -1153,9 +1207,9 @@ static void Interpolate3DCPUFwd(const framework::ExecutionContext& ctx, in_h, in_w, n, c, out_d, out_h, out_w, align_corners, align_mode, data_layout); } else if ("nearest" == interp_method) { - NearestNeighborInterpolate(input, output, ratio_d, ratio_h, ratio_w, n, - c, out_d, out_h, out_w, align_corners, - data_layout); + NearestNeighbor3DInterpolate(input, output, ratio_d, ratio_h, ratio_w, n, + c, out_d, out_h, out_w, align_corners, + data_layout); } } @@ -1357,11 +1411,9 @@ static void Interpolate2DCPUBwd(const framework::ExecutionContext& ctx, in_h, in_w, n, c, out_h, out_w, align_corners, align_mode, data_layout); } else if ("nearest" == interp_method) { - float ratio_d = 1.f; - int out_d = static_cast(1); - NearestNeighborInterpolateGrad(output_grad, input_grad, ratio_d, ratio_h, - ratio_w, n, c, out_d, out_h, out_w, - align_corners, data_layout); + NearestNeighborInterpolateGrad(output_grad, input_grad, ratio_h, ratio_w, + n, c, out_h, out_w, align_corners, + data_layout); } else if ("bicubic" == interp_method) { BicubicInterpolationGrad(output_grad, input_grad, ratio_h, ratio_w, in_h, in_w, n, c, out_h, out_w, align_corners, @@ -1511,9 +1563,9 @@ static void Interpolate3DCPUBwd(const framework::ExecutionContext& ctx, output_grad, input_grad, ratio_d, ratio_h, ratio_w, in_d, in_h, in_w, n, c, out_d, out_h, out_w, align_corners, align_mode, data_layout); } else if ("nearest" == interp_method) { - NearestNeighborInterpolateGrad(output_grad, input_grad, ratio_d, ratio_h, - ratio_w, n, c, out_d, out_h, out_w, - align_corners, data_layout); + NearestNeighbor3DInterpolateGrad(output_grad, input_grad, ratio_d, + ratio_h, ratio_w, n, c, out_d, out_h, + out_w, align_corners, data_layout); } } diff --git a/python/paddle/fluid/tests/unittests/test_nearest_interp_v2_op.py b/python/paddle/fluid/tests/unittests/test_nearest_interp_v2_op.py index 485558d1b7e281..79574134e9a003 100755 --- a/python/paddle/fluid/tests/unittests/test_nearest_interp_v2_op.py +++ b/python/paddle/fluid/tests/unittests/test_nearest_interp_v2_op.py @@ -80,7 +80,80 @@ def nearest_neighbor_interp_np(X, if data_layout == "NHWC": out = np.transpose(out, (0, 2, 3, 1)) # NCHW => NHWC - out = np.expand_dims(out, 2) + # out = np.expand_dims(out, 2) + return out.astype(X.dtype) + + +def nearest_neighbor_interp3d_np(X, + out_d, + out_h, + out_w, + scale_d=0, + scale_h=0, + scale_w=0, + out_size=None, + actual_shape=None, + align_corners=True, + data_layout='NCHW'): + """nearest neighbor interpolation implement in shape [N, C, H, W]""" + if data_layout == "NHWC": + X = np.transpose(X, (0, 4, 1, 2, 3)) # NDHWC => NCDHW + if out_size is not None: + out_d = out_size[0] + out_h = out_size[1] + out_w = out_size[2] + if actual_shape is not None: + out_d = actual_shape[0] + out_h = actual_shape[1] + out_w = actual_shape[2] + n, c, in_d, in_h, in_w = X.shape + + ratio_d = ratio_h = ratio_w = 0.0 + if (out_d > 1): + if (align_corners): + ratio_d = (in_d - 1.0) / (out_d - 1.0) + else: + if scale_d > 0: + ratio_d = 1.0 / scale_d + else: + ratio_d = 1.0 * in_d / out_d + if (out_h > 1): + if (align_corners): + ratio_h = (in_h - 1.0) / (out_h - 1.0) + else: + if scale_h > 0: + ratio_h = 1.0 / scale_h + else: + ratio_h = 1.0 * in_h / out_h + if (out_w > 1): + if (align_corners): + ratio_w = (in_w - 1.0) / (out_w - 1.0) + else: + if scale_w > 0: + ratio_w = 1.0 / scale_w + else: + ratio_w = 1.0 * in_w / out_w + out = np.zeros((n, c, out_d, out_h, out_w)) + + if align_corners: + for d in range(out_d): + in_d = int(ratio_d * d + 0.5) + for i in range(out_h): + in_i = int(ratio_h * i + 0.5) + for j in range(out_w): + in_j = int(ratio_w * j + 0.5) + out[:, :, d, i, j] = X[:, :, in_d, in_i, in_j] + else: + for d in range(out_d): + in_d = int(ratio_d * d) + for i in range(out_h): + in_i = int(ratio_h * i) + for j in range(out_w): + in_j = int(ratio_w * j) + out[:, :, d, i, j] = X[:, :, in_d, in_i, in_j] + + if data_layout == "NDHWC": + out = np.transpose(out, (0, 2, 3, 4, 1)) # NCDHW => NDHWC return out.astype(X.dtype) @@ -93,45 +166,81 @@ def setUp(self): self.op_type = "nearest_interp_v2" input_np = np.random.random(self.input_shape).astype("float64") - if self.data_layout == "NCHW": + if self.data_layout == "NCHW" and len(self.input_shape) == 4: + in_d = 1 in_h = self.input_shape[2] in_w = self.input_shape[3] else: + in_d = 1 in_h = self.input_shape[1] in_w = self.input_shape[2] + + if self.data_layout == "NCDHW" and len(self.input_shape) == 5: + in_d = self.input_shape[2] + in_h = self.input_shape[3] + in_w = self.input_shape[4] + else: + in_d = self.input_shape[1] + in_h = self.input_shape[2] + in_w = self.input_shape[3] + scale_d = 0 scale_h = 0 scale_w = 0 if self.scale: if isinstance(self.scale, float) or isinstance(self.scale, int): if self.scale > 0: - scale_h = scale_w = float(self.scale) + scale_d = scale_h = scale_w = float(self.scale) if isinstance(self.scale, list) and len(self.scale) == 1: - scale_w = scale_h = self.scale[0] + scale_d = scale_w = scale_h = self.scale[0] elif isinstance(self.scale, list) and len(self.scale) > 1: - scale_w = self.scale[1] - scale_h = self.scale[0] + if len(self.scale) == 5: + scale_w = self.scale[2] + scale_h = self.scale[1] + scale_d = self.scale[0] + else: + scale_w = self.scale[1] + scale_h = self.scale[0] + out_h = int(in_h * scale_h) out_w = int(in_w * scale_w) + out_d = int(in_d * scale_d) else: + if len(self.input_shape) == 5: + out_d = self.out_d out_h = self.out_h out_w = self.out_w - output_np = nearest_neighbor_interp_np( - input_np, out_h, out_w, scale_h, scale_w, self.out_size, - self.actual_shape, self.align_corners, self.data_layout) - print("input shape:", input_np.shape) + if len(self.input_shape) == 4: + output_np = nearest_neighbor_interp_np( + input_np, out_h, out_w, scale_h, scale_w, self.out_size, + self.actual_shape, self.align_corners, self.data_layout) + elif len(self.input_shape) == 5: + output_np = nearest_neighbor_interp3d_np( + input_np, out_d, out_h, out_w, scale_d, scale_h, scale_w, + self.out_size, self.actual_shape, self.align_corners, + self.data_layout) self.inputs = {'X': input_np} if self.out_size is not None: self.inputs['OutSize'] = self.out_size if self.actual_shape is not None: self.inputs['OutSize'] = self.actual_shape - self.attrs = { - 'out_h': self.out_h, - 'out_w': self.out_w, - 'interp_method': self.interp_method, - 'align_corners': self.align_corners, - 'data_layout': self.data_layout - } + if len(self.input_shape) == 5: + self.attrs = { + 'out_d': self.out_d, + 'out_h': self.out_h, + 'out_w': self.out_w, + 'interp_method': self.interp_method, + 'align_corners': self.align_corners, + 'data_layout': self.data_layout + } + else: + self.attrs = { + 'out_h': self.out_h, + 'out_w': self.out_w, + 'interp_method': self.interp_method, + 'align_corners': self.align_corners, + 'data_layout': self.data_layout + } if self.scale: if isinstance(self.scale, float) or isinstance(self.scale, int): if self.scale > 0: @@ -140,8 +249,6 @@ def setUp(self): self.scale = [self.scale[0], self.scale[0]] self.attrs['scale'] = self.scale self.outputs = {'Out': output_np} - print("=========================") - print(output_np.shape) def test_check_output(self): self.check_output() @@ -162,14 +269,14 @@ def init_test_case(self): class TestNearestNeighborInterpCase1(TestNearestInterpOp): def init_test_case(self): self.interp_method = 'nearest' - self.input_shape = [4, 1, 7, 8] + self.input_shape = [4, 1, 1, 7, 8] + self.out_d = 1 self.out_h = 1 self.out_w = 1 self.scale = 0. self.align_corners = True -""" class TestNearestNeighborInterpCase2(TestNearestInterpOp): def init_test_case(self): self.interp_method = 'nearest' @@ -372,6 +479,18 @@ def init_test_case(self): self.align_corners = True +class TestNearestNeighbor3DInterp(TestNearestInterpOp): + def init_test_case(self): + self.interp_method = 'nearest' + self.input_shape = [3, 2, 4, 7, 5] + self.out_d = 8 + self.out_h = 64 + self.out_w = 32 + self.scale = [4.0, 2.0, 3.0] + self.out_size = np.array([8, 66, 40]).astype("int32") + self.align_corners = True + + class TestNearestInterpOp_attr_tensor(OpTest): def setUp(self): self.out_size = None @@ -555,6 +674,29 @@ def test_case(self): self.assertTrue(np.allclose(out.numpy(), expect_res)) +class TestNearestInterp3DOpAPI_dy(unittest.TestCase): + def test_case(self): + import paddle + if core.is_compiled_with_cuda(): + place = core.CUDAPlace(0) + else: + place = core.CPUPlace() + with fluid.dygraph.guard(place): + input_data = np.random.random((2, 2, 6, 6, 6)).astype("int64") + scale_np = np.array([2, 2, 2]).astype("int64") + input_x = paddle.to_tensor(input_data) + scale = paddle.to_tensor(scale_np) + expect_res = nearest_neighbor_interp3d_np( + input_data, out_d=12, out_h=12, out_w=12, align_corners=False) + out = interpolate( + x=input_x, + scale_factor=scale, + mode="nearest", + align_corners=False, + data_format="NCDHW") + self.assertTrue(np.allclose(out.numpy(), expect_res)) + + class TestNearestInterpException(unittest.TestCase): def test_exception(self): input = fluid.data(name="input", shape=[1, 3, 6, 6], dtype="float32") @@ -574,7 +716,6 @@ def attr_scale_value(): self.assertRaises(TypeError, attr_scale_type) self.assertRaises(ValueError, attr_scale_value) -""" if __name__ == "__main__": unittest.main() diff --git a/python/paddle/nn/functional/common.py b/python/paddle/nn/functional/common.py index 655550c594df22..5a010ad2f20c55 100644 --- a/python/paddle/nn/functional/common.py +++ b/python/paddle/nn/functional/common.py @@ -335,8 +335,6 @@ def _is_list_or_turple_(data): if resample == 'NEAREST': align_corners = False - if resample in ['NEAREST'] and len(x.shape) == 4: - x = unsqueeze(x, axis=[2]) inputs = {"X": x} attrs = { @@ -471,8 +469,6 @@ def _is_list_or_turple_(data): out = _C_ops.trilinear_interp_v2(x, *dy_attr) elif resample_type == "nearest": out = _C_ops.nearest_interp_v2(x, *dy_attr) - if len(x.shape) == 4: - return squeeze(out, [2]) elif resample_type == "bicubic": out = _C_ops.bicubic_interp_v2(x, *dy_attr) return out @@ -482,8 +478,6 @@ def _is_list_or_turple_(data): inputs=inputs, outputs={"Out": out}, attrs=attrs) - if resample_type == "nearest" and len(x.shape) == 4: - return squeeze(out, [2]) return out From ad8a5a62d166e7399bfd29965f5fabd43f286480 Mon Sep 17 00:00:00 2001 From: tink2123 Date: Tue, 11 Jan 2022 08:47:15 +0000 Subject: [PATCH 3/4] fix approve ci, test=develop --- .../fluid/tests/unittests/test_nearest_interp_v2_op.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/python/paddle/fluid/tests/unittests/test_nearest_interp_v2_op.py b/python/paddle/fluid/tests/unittests/test_nearest_interp_v2_op.py index 79574134e9a003..445a5c064403cb 100755 --- a/python/paddle/fluid/tests/unittests/test_nearest_interp_v2_op.py +++ b/python/paddle/fluid/tests/unittests/test_nearest_interp_v2_op.py @@ -712,9 +712,19 @@ def attr_scale_type(): def attr_scale_value(): out = fluid.layers.resize_nearest(input, scale=-0.3) + def input_shape_error(): + x = fluid.data(name="input", shape=[1, 3], dtype="float32") + out = fluid.layers.resize_nearest(x, scale='scale') + + def mode_error(): + x = fluid.data(name="input", shape=[1, 3], dtype="float32") + out = fluid.layers.resize_bilinear(x, scale='scale') + self.assertRaises(ValueError, attr_data_format) self.assertRaises(TypeError, attr_scale_type) self.assertRaises(ValueError, attr_scale_value) + self.assertRaises(ValueError, input_shape_error) + self.assertRaises(ValueError, mode_error) if __name__ == "__main__": From 2f78a51e6fc8c79b7a3daef49e6d1905c192da00 Mon Sep 17 00:00:00 2001 From: tink2123 Date: Wed, 12 Jan 2022 03:16:48 +0000 Subject: [PATCH 4/4] fix approve ci, test=develop --- .../fluid/tests/unittests/test_nearest_interp_v2_op.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_nearest_interp_v2_op.py b/python/paddle/fluid/tests/unittests/test_nearest_interp_v2_op.py index 445a5c064403cb..e2ac98f7c9f1f6 100755 --- a/python/paddle/fluid/tests/unittests/test_nearest_interp_v2_op.py +++ b/python/paddle/fluid/tests/unittests/test_nearest_interp_v2_op.py @@ -699,6 +699,7 @@ def test_case(self): class TestNearestInterpException(unittest.TestCase): def test_exception(self): + import paddle input = fluid.data(name="input", shape=[1, 3, 6, 6], dtype="float32") def attr_data_format(): @@ -713,12 +714,13 @@ def attr_scale_value(): out = fluid.layers.resize_nearest(input, scale=-0.3) def input_shape_error(): - x = fluid.data(name="input", shape=[1, 3], dtype="float32") - out = fluid.layers.resize_nearest(x, scale='scale') + x = paddle.randn([1, 3]) + out = paddle.nn.functional.interpolate(x, scale_factor='scale') def mode_error(): - x = fluid.data(name="input", shape=[1, 3], dtype="float32") - out = fluid.layers.resize_bilinear(x, scale='scale') + x = paddle.randn([1, 3]) + out = paddle.nn.functional.interpolate( + x, scale_factor='scale', mode="BILINEAR") self.assertRaises(ValueError, attr_data_format) self.assertRaises(TypeError, attr_scale_type)