Skip to content

Commit

Permalink
Add MaxUnPool3D op and MaxUnPool1D op (#38716)
Browse files Browse the repository at this point in the history
* add maxunpool3d op

* update doc for maxunpool3d op

* update doc for maxunpool3d op

* update doc for maxunpool3d op

* update sample code for maxunpool3d

* add maxunpool1d op

* update some code for maxunpool1d
  • Loading branch information
andyjiang1116 authored Jan 10, 2022
1 parent 2238a53 commit 7e31542
Show file tree
Hide file tree
Showing 13 changed files with 1,287 additions and 6 deletions.
93 changes: 92 additions & 1 deletion paddle/fluid/operators/math/unpooling.cc
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -96,10 +96,101 @@ class Unpool2dMaxGradFunctor<platform::CPUDeviceContext, T> {
}
}
};

template <typename T>
class Unpool3dMaxFunctor<platform::CPUDeviceContext, T> {
 public:
  // Forward 3-D max unpooling on CPU: for every (batch, channel) slice,
  // scatter each input value to the flat output offset recorded in
  // `indices`.  All tensors use NCDHW layout; positions of `output` that
  // receive no index keep whatever value they already hold.
  void operator()(const platform::CPUDeviceContext& context,
                  const framework::Tensor& input,
                  const framework::Tensor& indices, framework::Tensor* output) {
    const int batch_size = input.dims()[0];
    // Per-channel element counts: D * H * W for input and output.
    const int in_slice = input.dims()[2] * input.dims()[3] * input.dims()[4];
    const int channels = output->dims()[1];
    const int out_slice =
        output->dims()[2] * output->dims()[3] * output->dims()[4];
    const T* src = input.data<T>();
    const int* idx = indices.data<int>();
    T* dst = output->mutable_data<T>(context.GetPlace());
    for (int n = 0; n < batch_size; ++n) {
      for (int ch = 0; ch < channels; ++ch) {
        for (int pos = 0; pos < in_slice; ++pos) {
          const int target = idx[pos];
          // Guard against out-of-range indices before writing.
          PADDLE_ENFORCE_LT(
              target, out_slice,
              platform::errors::InvalidArgument(
                  "index should less than output tensor depth * output tensor "
                  "height "
                  "* output tensor width. Expected %ld < %ld, but got "
                  "%ld >= %ld. Please check input value.",
                  target, out_slice, target, out_slice));
          dst[target] = src[pos];
        }
        // Advance all cursors to the next (batch, channel) slice.
        src += in_slice;
        idx += in_slice;
        dst += out_slice;
      }
    }
  }
};
template <typename T>
class Unpool3dMaxGradFunctor<platform::CPUDeviceContext, T> {
 public:
  // Backward 3-D max unpooling on CPU: each input-gradient element gathers
  // the output-gradient value at the flat offset recorded in `indices`.
  // All tensors use NCDHW layout; `output` is consulted only for its dims.
  void operator()(const platform::CPUDeviceContext& context,
                  const framework::Tensor& input,
                  const framework::Tensor& indices,
                  const framework::Tensor& output,
                  const framework::Tensor& output_grad,
                  framework::Tensor* input_grad) {
    const int batch_size = input.dims()[0];
    // Per-channel element counts: D * H * W for input and output.
    const int in_slice = input.dims()[2] * input.dims()[3] * input.dims()[4];
    const int channels = output.dims()[1];
    const int out_slice =
        output.dims()[2] * output.dims()[3] * output.dims()[4];
    const int* idx = indices.data<int>();
    const T* grad_src = output_grad.data<T>();
    T* grad_dst = input_grad->mutable_data<T>(context.GetPlace());

    for (int n = 0; n < batch_size; ++n) {
      for (int ch = 0; ch < channels; ++ch) {
        for (int pos = 0; pos < in_slice; ++pos) {
          const int target = idx[pos];
          // Guard against out-of-range indices before reading.
          PADDLE_ENFORCE_LT(
              target, out_slice,
              platform::errors::InvalidArgument(
                  "index should less than output tensor depth * output tensor "
                  "height "
                  "* output tensor width. Expected %ld < %ld, but got "
                  "%ld >= %ld. Please check input value.",
                  target, out_slice, target, out_slice));
          grad_dst[pos] = grad_src[target];
        }
        // Advance all cursors to the next (batch, channel) slice.
        grad_dst += in_slice;
        idx += in_slice;
        grad_src += out_slice;
      }
    }
  }
};

// Explicit instantiations: CPU unpooling functors are compiled for float
// and double element types only.
template class Unpool2dMaxGradFunctor<platform::CPUDeviceContext, float>;
template class Unpool2dMaxGradFunctor<platform::CPUDeviceContext, double>;
template class Unpool2dMaxFunctor<platform::CPUDeviceContext, float>;
template class Unpool2dMaxFunctor<platform::CPUDeviceContext, double>;
template class Unpool3dMaxGradFunctor<platform::CPUDeviceContext, float>;
template class Unpool3dMaxGradFunctor<platform::CPUDeviceContext, double>;
template class Unpool3dMaxFunctor<platform::CPUDeviceContext, float>;
template class Unpool3dMaxFunctor<platform::CPUDeviceContext, double>;
} // namespace math
} // namespace operators
} // namespace paddle
113 changes: 112 additions & 1 deletion paddle/fluid/operators/math/unpooling.cu
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
/* Copyright (c) 2016 paddlepaddle Authors. All Rights Reserved.
/* Copyright (c) 2022 paddlepaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -51,6 +51,45 @@ __global__ void KernelUnpool2dMaxGrad(
/*
* All tensors are in NCHW format.
*/

// Forward 3-D max-unpool kernel.  Each linear index covers one input
// element; its value is scattered to the output offset stored in
// `indices_data`.  All tensors are in NCDHW format.
template <typename T>
__global__ void KernelUnpool3dMax(const int nthreads, const T* input_data,
                                  const int* indices_data,
                                  const int input_depth, const int input_height,
                                  const int input_width, const int channels,
                                  T* output_data, const int output_depth,
                                  const int output_height,
                                  const int output_width) {
  CUDA_KERNEL_LOOP(linearIndex, nthreads) {
    int c = (linearIndex / input_depth / input_width / input_height) % channels;
    int n = linearIndex / input_depth / input_width / input_height / channels;
    // Use a per-iteration local pointer instead of mutating the
    // `output_data` parameter: if CUDA_KERNEL_LOOP executes more than one
    // iteration per thread (grid-stride style), the original
    // `output_data += ...` would accumulate offsets across iterations and
    // write out of bounds.
    T* output_slice =
        output_data +
        (n * channels + c) * output_depth * output_height * output_width;
    int maxind = indices_data[linearIndex];
    output_slice[maxind] = input_data[linearIndex];
  }
}

// Backward 3-D max-unpool kernel.  Each linear index covers one input
// element; its gradient is gathered from the output-gradient offset stored
// in `indices_data`.  All tensors are in NCDHW format.
template <typename T>
__global__ void KernelUnpool3dMaxGrad(
    const int nthreads, const T* input_data, const int* indices_data,
    const int input_depth, const int input_height, const int input_width,
    const int channels, const T* output_data, const T* output_grad,
    const int output_depth, const int output_height, const int output_width,
    T* input_grad) {
  CUDA_KERNEL_LOOP(linearIndex, nthreads) {
    int c = (linearIndex / input_depth / input_width / input_height) % channels;
    int n = linearIndex / input_depth / input_width / input_height / channels;
    // Use a per-iteration local pointer instead of mutating the
    // `output_grad` parameter: if CUDA_KERNEL_LOOP executes more than one
    // iteration per thread (grid-stride style), the original
    // `output_grad += ...` would accumulate offsets across iterations and
    // read out of bounds.
    const T* grad_slice =
        output_grad +
        (n * channels + c) * output_depth * output_height * output_width;
    int maxind = indices_data[linearIndex];
    input_grad[linearIndex] = grad_slice[maxind];
  }
}
/*
* All tensors are in NCDHW format.
*/

template <typename T>
class Unpool2dMaxFunctor<platform::CUDADeviceContext, T> {
public:
Expand Down Expand Up @@ -112,10 +151,82 @@ class Unpool2dMaxGradFunctor<platform::CUDADeviceContext, T> {
output_width, input_grad_data);
}
};

template <typename T>
class Unpool3dMaxFunctor<platform::CUDADeviceContext, T> {
 public:
  // Forward 3-D max unpool on GPU: launches KernelUnpool3dMax on the
  // context's stream to scatter every input element to the output offset
  // stored in `indices`.  Tensors are laid out NCDHW.
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& input,
                  const framework::Tensor& indices, framework::Tensor* output) {
    const int batch_size = input.dims()[0];  // NOTE(review): not used below
    const int input_depth = input.dims()[2];
    const int input_height = input.dims()[3];
    const int input_width = input.dims()[4];
    const int output_channels = output->dims()[1];
    const int output_depth = output->dims()[2];
    const int output_height = output->dims()[3];
    const int output_width = output->dims()[4];
    const T* input_data = input.data<T>();
    const int* indices_data = indices.data<int>();
    T* output_data = output->mutable_data<T>(context.GetPlace());
// Block size differs per backend — presumably tuned: 256 threads under
// HIP/ROCm, 1024 under CUDA.  TODO confirm against platform limits.
#ifdef __HIPCC__
    int threads = 256;
#else
    int threads = 1024;
#endif
    // One thread per input element, rounded up to whole blocks.
    int grid = (input.numel() + threads - 1) / threads;
    KernelUnpool3dMax<T><<<grid, threads, 0, context.stream()>>>(
        input.numel(), input_data, indices_data, input_depth, input_height,
        input_width, output_channels, output_data, output_depth, output_height,
        output_width);
  }
};
/*
* All tensors are in NCDHW format.
*/
template <typename T>
class Unpool3dMaxGradFunctor<platform::CUDADeviceContext, T> {
 public:
  // Backward 3-D max unpool on GPU: launches KernelUnpool3dMaxGrad on the
  // context's stream so each input-gradient element gathers the
  // output-gradient value at the offset stored in `indices`.  NCDHW layout.
  void operator()(const platform::CUDADeviceContext& context,
                  const framework::Tensor& input,
                  const framework::Tensor& indices,
                  const framework::Tensor& output,
                  const framework::Tensor& output_grad,
                  framework::Tensor* input_grad) {
    const int batch_size = input.dims()[0];  // NOTE(review): not used below
    const int input_depth = input.dims()[2];
    const int input_height = input.dims()[3];
    const int input_width = input.dims()[4];
    const int output_channels = output.dims()[1];
    const int output_depth = output.dims()[2];
    const int output_height = output.dims()[3];
    const int output_width = output.dims()[4];
    const T* input_data = input.data<T>();
    const int* indices_data = indices.data<int>();
    const T* output_data = output.data<T>();
    const T* output_grad_data = output_grad.data<T>();
    T* input_grad_data = input_grad->mutable_data<T>(context.GetPlace());
// Block size differs per backend — presumably tuned: 256 threads under
// HIP/ROCm, 1024 under CUDA.  TODO confirm against platform limits.
#ifdef __HIPCC__
    int threads = 256;
#else
    int threads = 1024;
#endif
    // One thread per input element, rounded up to whole blocks.
    int grid = (input.numel() + threads - 1) / threads;
    KernelUnpool3dMaxGrad<T><<<grid, threads, 0, context.stream()>>>(
        input.numel(), input_data, indices_data, input_depth, input_height,
        input_width, output_channels, output_data, output_grad_data,
        output_depth, output_height, output_width, input_grad_data);
  }
};

// Explicit instantiations: GPU unpooling functors are compiled for float
// and double element types only.
template class Unpool2dMaxGradFunctor<platform::CUDADeviceContext, float>;
template class Unpool2dMaxGradFunctor<platform::CUDADeviceContext, double>;
template class Unpool2dMaxFunctor<platform::CUDADeviceContext, float>;
template class Unpool2dMaxFunctor<platform::CUDADeviceContext, double>;
template class Unpool3dMaxGradFunctor<platform::CUDADeviceContext, float>;
template class Unpool3dMaxGradFunctor<platform::CUDADeviceContext, double>;
template class Unpool3dMaxFunctor<platform::CUDADeviceContext, float>;
template class Unpool3dMaxFunctor<platform::CUDADeviceContext, double>;
} // namespace math
} // namespace operators
} // namespace paddle
18 changes: 17 additions & 1 deletion paddle/fluid/operators/math/unpooling.h
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -33,6 +33,22 @@ class Unpool2dMaxGradFunctor {
const framework::Tensor& output_grad,
framework::Tensor* input_grad);
};

// Forward 3-D max unpooling: scatters `input` values into `output` at the
// flat per-channel offsets given by `indices`.  Specialized per device
// context in the corresponding .cc/.cu files.
template <typename DeviceContext, typename T>
class Unpool3dMaxFunctor {
 public:
  void operator()(const DeviceContext& context, const framework::Tensor& input,
                  const framework::Tensor& indices, framework::Tensor* output);
};
// Backward 3-D max unpooling: gathers `output_grad` values into
// `input_grad` at the flat per-channel offsets given by `indices`.
// Specialized per device context in the corresponding .cc/.cu files.
// `typename T` (was `class T`) for consistency with Unpool3dMaxFunctor;
// the two keywords are interchangeable here.
template <typename DeviceContext, typename T>
class Unpool3dMaxGradFunctor {
 public:
  void operator()(const DeviceContext& context, const framework::Tensor& input,
                  const framework::Tensor& indices,
                  const framework::Tensor& output,
                  const framework::Tensor& output_grad,
                  framework::Tensor* input_grad);
};
} // namespace math
} // namespace operators
} // namespace paddle
Loading

0 comments on commit 7e31542

Please sign in to comment.