From 7f09cd156e67299cc11fd6818fb25c48e439f69d Mon Sep 17 00:00:00 2001
From: Sandeep Krishnamurthy
Date: Thu, 7 Feb 2019 17:04:32 -0800
Subject: [PATCH 01/10] CPU implementation without Kernel launch/map

---
 src/operator/image/image_random-inl.h | 48 ++++++++++++---------------
 1 file changed, 22 insertions(+), 26 deletions(-)

diff --git a/src/operator/image/image_random-inl.h b/src/operator/image/image_random-inl.h
index 448016341f21..a81c31b2f714 100644
--- a/src/operator/image/image_random-inl.h
+++ b/src/operator/image/image_random-inl.h
@@ -78,37 +78,33 @@ inline bool ToTensorType(const nnvm::NodeAttrs& attrs,
 }
 
 // Operator Implementation
-
-template<int req>
-struct totensor_forward {
-  template<typename DType>
-  MSHADOW_XINLINE static void Map(uint32_t c, float* out_data, const DType* in_data,
-                                  const int length, const int channel, const int step,
-                                  const float normalize_factor = 255.0f) {
-    #pragma omp parallel for
+template<typename DType, int req>
+inline void ToTensor(float* out_data, const DType* in_data,
+                     const int length,
+                     const int channels,
+                     const int step = 0,
+                     const float normalize_factor = 255.0f) {
+
+  #pragma omp parallel for collapse(2)
+  for (int c = 0; c < channels; ++c) {
     for (int i = 0; i < length; ++i) {
       KERNEL_ASSIGN(out_data[step + c*length + i], req,
-                    (in_data[step + i*channel + c]) / normalize_factor);
+                    (in_data[step + i*channels + c]) / normalize_factor);
     }
   }
-};
-
-template<typename xpu>
-void ToTensorImpl(const OpContext &ctx,
-                  const std::vector<TBlob> &inputs,
-                  const std::vector<TBlob> &outputs,
-                  const std::vector<OpReqType> &req,
-                  const int length,
-                  const uint32_t channel,
-                  const int step = 0) {
-  mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
+}
 
+inline void ToTensorImpl(const std::vector<TBlob> &inputs,
+                         const std::vector<TBlob> &outputs,
+                         const std::vector<OpReqType> &req,
+                         const int length,
+                         const int channel,
+                         const int step = 0) {
   MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, DType, {
     MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, {
       float* output = outputs[0].dptr<float>();
      DType* input = inputs[0].dptr<DType>();
-      mxnet_op::Kernel<totensor_forward<req_type>, xpu>::Launch(
-          s, channel, output, input, length, channel, step);
+      ToTensor<DType, req_type>(output, input, length, channel, step);
    });
  });
 }
@@ -129,18 +125,18 @@ void ToTensorOpForward(const nnvm::NodeAttrs &attrs,
   // 3D Input - (h, w, c)
   if (inputs[0].ndim() == 3) {
     const int length = inputs[0].shape_[0] * inputs[0].shape_[1];
-    const uint32_t channel = inputs[0].shape_[2];
-    ToTensorImpl<xpu>(ctx, inputs, outputs, req, length, channel);
+    const int channel = (int)inputs[0].shape_[2];
+    ToTensorImpl(inputs, outputs, req, length, channel);
   } else if (inputs[0].ndim() == 4) {
     // 4D input (n, h, w, c)
     const int batch_size = inputs[0].shape_[0];
     const int length = inputs[0].shape_[1] * inputs[0].shape_[2];
-    const uint32_t channel = inputs[0].shape_[3];
+    const int channel = (int)inputs[0].shape_[3];
     const int step = channel * length;
 
     #pragma omp parallel for
     for (auto n = 0; n < batch_size; ++n) {
-      ToTensorImpl<xpu>(ctx, inputs, outputs, req, length, channel, n*step);
+      ToTensorImpl(inputs, outputs, req, length, channel, n*step);
     }
   }
 }
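
Note: the rewritten loop above is the whole ToTensor transform — it turns an interleaved HWC image into planar CHW output and divides by 255. A minimal standalone sketch of that index mapping, with a hypothetical 2x2 RGB image (plain C++, no MXNet dependencies):

  #include <cassert>

  // HWC -> CHW with normalization, mirroring the loop in the patch above:
  // out[c*length + i] = in[i*channels + c] / 255, where length = H*W.
  void to_tensor_ref(float* out, const unsigned char* in,
                     const int length, const int channels) {
    for (int c = 0; c < channels; ++c)
      for (int i = 0; i < length; ++i)
        out[c*length + i] = in[i*channels + c] / 255.0f;
  }

  int main() {
    // 2x2 image, 3 channels, interleaved HWC bytes: R, G, B, white.
    const unsigned char hwc[12] = {255, 0, 0,   0, 255, 0,
                                   0, 0, 255,   255, 255, 255};
    float chw[12];
    to_tensor_ref(chw, hwc, 4, 3);
    assert(chw[0] == 1.0f);      // R plane, pixel 0
    assert(chw[4 + 1] == 1.0f);  // G plane, pixel 1
    assert(chw[8 + 2] == 1.0f);  // B plane, pixel 2
    return 0;
  }
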
From 68eb95d3150a039aaf93ad23bcb5083fd337912f Mon Sep 17 00:00:00 2001
From: Sandeep Krishnamurthy
Date: Fri, 8 Feb 2019 13:48:25 -0800
Subject: [PATCH 02/10] Optimal CUDA support for 3D ToTensor operator

---
 src/operator/image/image_random-inl.h | 37 ++++++++++++---
 src/operator/image/image_random.cu    | 65 +++++++++++++++++++++++++++
 2 files changed, 95 insertions(+), 7 deletions(-)

diff --git a/src/operator/image/image_random-inl.h b/src/operator/image/image_random-inl.h
index a81c31b2f714..2c1fca1177d2 100644
--- a/src/operator/image/image_random-inl.h
+++ b/src/operator/image/image_random-inl.h
@@ -43,8 +43,17 @@ namespace mxnet {
 namespace op {
 namespace image {
 
-// There are no parameters for this operator.
-// Hence, no parameter registration.
+using namespace mshadow;
+
+#if MXNET_USE_CUDA
+// NOTE: Kernel launch/map was extremely costly.
+// Hence, we use separate CUDA kernels for these operators.
+template<typename DType, typename T1, typename T2>
+void ToTensorImplCUDA(mshadow::Stream<gpu> *s,
+                      const T1 input,
+                      const T2 output,
+                      const int req);
+#endif  // MXNET_USE_CUDA
 
 // Shape and Type inference for image to tensor operator
 inline bool ToTensorShape(const nnvm::NodeAttrs& attrs,
@@ -84,7 +93,6 @@ inline void ToTensor(float* out_data, const DType* in_data,
                      const int channels,
                      const int step = 0,
                      const float normalize_factor = 255.0f) {
-
   #pragma omp parallel for collapse(2)
   for (int c = 0; c < channels; ++c) {
     for (int i = 0; i < length; ++i) {
@@ -119,19 +127,34 @@ void ToTensorOpForward(const nnvm::NodeAttrs &attrs,
   CHECK_EQ(outputs.size(), 1U);
   CHECK_EQ(req.size(), 1U);
 
+  // We do not use temp buffer when performing the operation.
+  // Hence, this check is necessary.
   CHECK_EQ(req[0], kWriteTo)
     << "`to_tensor` does not support inplace updates";
 
-  // 3D Input - (h, w, c)
-  if (inputs[0].ndim() == 3) {
+  if (std::is_same<xpu, gpu>::value) {
+  #if MXNET_USE_CUDA
+    mshadow::Stream<gpu> *s = ctx.get_stream<gpu>();
+    MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, DType, {
+      MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, {
+        if (inputs[0].ndim() == 3) {
+          Tensor<gpu, 3, DType> input = inputs[0].get<gpu, 3, DType>(s);
+          Tensor<gpu, 3, float> output = outputs[0].get<gpu, 3, float>(s);
+          ToTensorImplCUDA<DType, Tensor<gpu, 3, DType>, Tensor<gpu, 3, float>>(s, input, output, req_type);
+        }
+      });
+    });
+#endif  // MXNET_USE_CUDA
+  } else if (inputs[0].ndim() == 3) {
+    // 3D Input - (h, w, c)
     const int length = inputs[0].shape_[0] * inputs[0].shape_[1];
-    const int channel = (int)inputs[0].shape_[2];
+    const int channel = static_cast<int>(inputs[0].shape_[2]);
     ToTensorImpl(inputs, outputs, req, length, channel);
   } else if (inputs[0].ndim() == 4) {
     // 4D input (n, h, w, c)
     const int batch_size = inputs[0].shape_[0];
     const int length = inputs[0].shape_[1] * inputs[0].shape_[2];
-    const int channel = (int)inputs[0].shape_[3];
+    const int channel = static_cast<int>(inputs[0].shape_[3]);
     const int step = channel * length;
 
     #pragma omp parallel for

diff --git a/src/operator/image/image_random.cu b/src/operator/image/image_random.cu
index 5f9aff27e85b..ad07e43b2fc7 100644
--- a/src/operator/image/image_random.cu
+++ b/src/operator/image/image_random.cu
@@ -21,6 +21,7 @@
  * \file image_random.cu
  * \brief GPU Implementation of image transformation operators
 */
+#include <cuda_runtime_api.h>
 #include "./image_random-inl.h"
 #include "../elemwise_op_common.h"
 
@@ -28,6 +29,70 @@ namespace mxnet {
 namespace op {
 namespace image {
 
+using namespace mshadow;
+
+template<typename xpu, typename Dtype>
+__global__ void ToTensorCudaKernel(const Tensor<xpu, 3, Dtype> input,
+                                   const Tensor<xpu, 3, float> output,
+                                   const int req,
+                                   int N, int H, int W, int C,
+                                   const float normalize_factor = 255.0f) {
+  // We process one image per thread block.
+  // In 3D case, we have only 1 block i.e., blockIdx.x
+  // We do not use it.
+  /*
+  const int n = blockIdx.x;
+  const int stride = H*W*C;
+
+  // Get pointer to my blocks image
+  int step = 0;
+  if (N > 0) {
+    step = n * stride;
+  }
+  */
+  for (int c = 0; c < C; ++c) {
+    for (int h = threadIdx.y; h < H; h += blockDim.y) {
+      for (int w = threadIdx.x; w < W; w += blockDim.x) {
+        KERNEL_ASSIGN(output[c][h][w], req,
+                      input[h][w][c] / normalize_factor);
+      }
+    }
+  }
+}
+
+template<typename DType, typename T1, typename T2>
+void ToTensorImplCUDA(mshadow::Stream<gpu> *s,
+                      const T1 input,
+                      const T2 output,
+                      const int req,
+                      const float normalize_factor = 255.0f) {
+  int blocks, H, W, C, N;
+  cudaStream_t stream = mshadow::Stream<gpu>::GetStream(s);
+  if (std::is_same<T1, Tensor<gpu, 3, DType>>::value) {
+    // 3D Input - (H, W, C)
+    N = 0;
+    H = input.size(0);
+    W = input.size(1);
+    C = input.size(2);
+    blocks = 1;
+  } /*else {
+    // 4D Input - (N, H, W, C)
+    N = input.size()[0];
+    H = input.size()[1];
+    W = input.size()[2];
+    C = input.size()[3];
+    // blocks = N > 0 ? N : 1;
+    blocks = N;
+  }*/
+  // One block per image.
+  // Number of threads = (32, 32) is optimal, because
+  // computation is minimal and overhead of CUDA preparing
+  // all threads is minimal.
+  ToTensorCudaKernel<gpu, DType><<<blocks, dim3(32, 32), 0, stream>>>(input,
+      output, req, N, H, W, C, normalize_factor);
+  MSHADOW_CUDA_POST_KERNEL_CHECK(ToTensorCudaKernel);
+}
+
 NNVM_REGISTER_OP(_image_to_tensor)
 .set_attr<FCompute>("FCompute<gpu>", ToTensorOpForward<gpu>);
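
Note: the 3D kernel above covers the whole H x W plane with a single 32x32 thread block — each thread starts at (threadIdx.x, threadIdx.y) and strides by blockDim until the image is exhausted. A self-contained CUDA sketch of the same striding pattern, using raw pointers and hypothetical sizes in place of the mshadow tensors:

  #include <cstdio>
  #include <cuda_runtime.h>

  // One thread block strides over an HxWxC image, writing planar CHW output.
  __global__ void hwc_to_chw(const unsigned char* in, float* out,
                             int H, int W, int C, float norm) {
    for (int c = 0; c < C; ++c)
      for (int h = threadIdx.y; h < H; h += blockDim.y)
        for (int w = threadIdx.x; w < W; w += blockDim.x)
          out[(c*H + h)*W + w] = in[(h*W + w)*C + c] / norm;
  }

  int main() {
    const int H = 64, W = 64, C = 3, n = H*W*C;
    unsigned char *in;
    float *out;
    cudaMallocManaged(&in, n);
    cudaMallocManaged(&out, n * sizeof(float));
    for (int i = 0; i < n; ++i) in[i] = i % 256;
    hwc_to_chw<<<1, dim3(32, 32)>>>(in, out, H, W, C, 255.0f);
    cudaDeviceSynchronize();
    printf("out[0] = %f\n", out[0]);  // first red value / 255
    cudaFree(in);
    cudaFree(out);
    return 0;
  }
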
From d1c6faadd1dfd30706629dc553e2897c5de21a6e Mon Sep 17 00:00:00 2001
From: Sandeep Krishnamurthy
Date: Fri, 8 Feb 2019 18:16:39 -0800
Subject: [PATCH 03/10] Add CUDA kernel for 4D inputs

---
 src/operator/image/image_random-inl.h | 30 +++++++++----
 src/operator/image/image_random.cu    | 64 +++++++++++++++++----------
 2 files changed, 62 insertions(+), 32 deletions(-)

diff --git a/src/operator/image/image_random-inl.h b/src/operator/image/image_random-inl.h
index 2c1fca1177d2..c2214461308c 100644
--- a/src/operator/image/image_random-inl.h
+++ b/src/operator/image/image_random-inl.h
@@ -52,7 +52,8 @@ template<typename DType, typename T1, typename T2>
 void ToTensorImplCUDA(mshadow::Stream<gpu> *s,
                       const T1 input,
                       const T2 output,
-                      const int req);
+                      const int req,
+                      const float normalize_factor);
 #endif  // MXNET_USE_CUDA
 
 // Shape and Type inference for image to tensor operator
@@ -89,10 +90,10 @@ inline bool ToTensorType(const nnvm::NodeAttrs& attrs,
 // Operator Implementation
 template<typename DType, int req>
 inline void ToTensor(float* out_data, const DType* in_data,
-                     const int length,
-                     const int channels,
-                     const int step = 0,
-                     const float normalize_factor = 255.0f) {
+                     const int length,
+                     const int channels,
+                     const float normalize_factor,
+                     const int step = 0) {
   #pragma omp parallel for collapse(2)
   for (int c = 0; c < channels; ++c) {
     for (int i = 0; i < length; ++i) {
@@ -107,6 +108,7 @@ inline void ToTensorImpl(const std::vector<TBlob> &inputs,
                          const std::vector<OpReqType> &req,
                          const int length,
                          const int channel,
+                         const float normalize_factor,
                          const int step = 0) {
   MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, DType, {
     MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, {
@@ -132,6 +134,8 @@ void ToTensorOpForward(const nnvm::NodeAttrs &attrs,
   CHECK_EQ(req[0], kWriteTo)
     << "`to_tensor` does not support inplace updates";
 
+  const float normalize_factor = 255.0f;
+
   if (std::is_same<xpu, gpu>::value) {
   #if MXNET_USE_CUDA
     mshadow::Stream<gpu> *s = ctx.get_stream<gpu>();
@@ -140,16 +144,23 @@ void ToTensorOpForward(const nnvm::NodeAttrs &attrs,
         if (inputs[0].ndim() == 3) {
           Tensor<gpu, 3, DType> input = inputs[0].get<gpu, 3, DType>(s);
           Tensor<gpu, 3, float> output = outputs[0].get<gpu, 3, float>(s);
-          ToTensorImplCUDA<DType, Tensor<gpu, 3, DType>, Tensor<gpu, 3, float>>(s, input, output, req_type);
+          ToTensorImplCUDA<DType, Tensor<gpu, 3, DType>, Tensor<gpu, 3, float>>
+            (s, input, output, req_type, normalize_factor);
+        } else {
+          Tensor<gpu, 4, DType> input = inputs[0].get<gpu, 4, DType>(s);
+          Tensor<gpu, 4, float> output = outputs[0].get<gpu, 4, float>(s);
+          ToTensorImplCUDA<DType, Tensor<gpu, 4, DType>, Tensor<gpu, 4, float>>
+            (s, input, output, req_type, normalize_factor);
         }
       });
     });
-#endif  // MXNET_USE_CUDA
+  #endif  // MXNET_USE_CUDA
   } else if (inputs[0].ndim() == 3) {
     // 3D Input - (h, w, c)
     const int length = inputs[0].shape_[0] * inputs[0].shape_[1];
     const int channel = static_cast<int>(inputs[0].shape_[2]);
-    ToTensorImpl(inputs, outputs, req, length, channel);
+    ToTensorImpl(inputs, outputs, req, length,
+                 channel, normalize_factor);
   } else if (inputs[0].ndim() == 4) {
     // 4D input (n, h, w, c)
     const int batch_size = inputs[0].shape_[0];
@@ -159,7 +170,8 @@ void ToTensorOpForward(const nnvm::NodeAttrs &attrs,
 
     #pragma omp parallel for
     for (auto n = 0; n < batch_size; ++n) {
-      ToTensorImpl(inputs, outputs, req, length, channel, n*step);
+      ToTensorImpl(inputs, outputs, req, length, channel,
+                   normalize_factor, n*step);
     }
   }
 }

diff --git a/src/operator/image/image_random.cu b/src/operator/image/image_random.cu
index ad07e43b2fc7..6fe53832a89e 100644
--- a/src/operator/image/image_random.cu
+++ b/src/operator/image/image_random.cu
@@ -31,25 +31,19 @@ namespace image {
 
 using namespace mshadow;
 
+// ToTensor Kernel for 3D input
 template<typename xpu, typename Dtype>
 __global__ void ToTensorCudaKernel(const Tensor<xpu, 3, Dtype> input,
                                    const Tensor<xpu, 3, float> output,
                                    const int req,
-                                   int N, int H, int W, int C,
-                                   const float normalize_factor = 255.0f) {
+                                   const int N,
+                                   const int H,
+                                   const int W,
+                                   const int C,
+                                   const float normalize_factor) {
   // We process one image per thread block.
   // In 3D case, we have only 1 block i.e., blockIdx.x
   // We do not use it.
-  /*
-  const int n = blockIdx.x;
-  const int stride = H*W*C;
-
-  // Get pointer to my blocks image
-  int step = 0;
-  if (N > 0) {
-    step = n * stride;
-  }
-  */
   for (int c = 0; c < C; ++c) {
     for (int h = threadIdx.y; h < H; h += blockDim.y) {
       for (int w = threadIdx.x; w < W; w += blockDim.x) {
@@ -60,12 +54,35 @@ __global__ void ToTensorCudaKernel(const Tensor<xpu, 3, Dtype> input,
   }
 }
 
+// ToTensor Kernel for 4D input
+template<typename xpu, typename Dtype>
+__global__ void ToTensorCudaKernel(const Tensor<xpu, 4, Dtype> input,
+                                   const Tensor<xpu, 4, float> output,
+                                   const int req,
+                                   const int N,
+                                   const int H,
+                                   const int W,
+                                   const int C,
+                                   const float normalize_factor) {
+  // We process one image per thread block.
+  const int n = blockIdx.x;
+
+  for (int c = 0; c < C; ++c) {
+    for (int h = threadIdx.y; h < H; h += blockDim.y) {
+      for (int w = threadIdx.x; w < W; w += blockDim.x) {
+        KERNEL_ASSIGN(output[n][c][h][w], req,
+                      input[n][h][w][c] / normalize_factor);
+      }
+    }
+  }
+}
+
 template<typename DType, typename T1, typename T2>
 void ToTensorImplCUDA(mshadow::Stream<gpu> *s,
                       const T1 input,
                       const T2 output,
                       const int req,
-                      const float normalize_factor = 255.0f) {
+                      const float normalize_factor) {
   int blocks, H, W, C, N;
   cudaStream_t stream = mshadow::Stream<gpu>::GetStream(s);
   if (std::is_same<T1, Tensor<gpu, 3, DType>>::value) {
@@ -75,22 +92,23 @@ void ToTensorImplCUDA(mshadow::Stream<gpu> *s,
     W = input.size(1);
     C = input.size(2);
     blocks = 1;
-  } /*else {
+  } else {
     // 4D Input - (N, H, W, C)
-    N = input.size()[0];
-    H = input.size()[1];
-    W = input.size()[2];
-    C = input.size()[3];
-    // blocks = N > 0 ? N : 1;
+    N = input.size(0);
+    H = input.size(1);
+    W = input.size(2);
+    C = input.size(3);
+    blocks = N > 0 ? N : 1;
     blocks = N;
-  }*/
+  }
   // One block per image.
   // Number of threads = (32, 32) is optimal, because
   // computation is minimal and overhead of CUDA preparing
   // all threads is minimal.
-  ToTensorCudaKernel<gpu, DType><<<blocks, dim3(32, 32), 0, stream>>>(input,
-      output, req, N, H, W, C, normalize_factor);
-  MSHADOW_CUDA_POST_KERNEL_CHECK(ToTensorCudaKernel);
+  ToTensorCudaKernel<gpu, DType>
+    <<<blocks, dim3(32, 32), 0, stream>>>(input, output,
+        req, N, H, W, C, normalize_factor);
+  MSHADOW_CUDA_POST_KERNEL_CHECK(ToTensorCudaKernel);
 }
 
 NNVM_REGISTER_OP(_image_to_tensor)
 .set_attr<FCompute>("FCompute<gpu>", ToTensorOpForward<gpu>);
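
Note: for batched input the only new ingredient is blockIdx.x — block n transposes image n, so the grid size equals the batch size. Extending the earlier raw-pointer sketch to the 4D case (again with hypothetical sizes):

  #include <cstdio>
  #include <cuda_runtime.h>

  // One thread block per image: blockIdx.x selects the image, and the
  // 32x32 threads stride over H and W exactly as in the 3D case.
  __global__ void batched_hwc_to_chw(const unsigned char* in, float* out,
                                     int H, int W, int C, float norm) {
    const int n = blockIdx.x;
    const int img = H * W * C;  // elements per image
    for (int c = 0; c < C; ++c)
      for (int h = threadIdx.y; h < H; h += blockDim.y)
        for (int w = threadIdx.x; w < W; w += blockDim.x)
          out[n*img + (c*H + h)*W + w] = in[n*img + (h*W + w)*C + c] / norm;
  }

  int main() {
    const int N = 4, H = 32, W = 32, C = 3, total = N*H*W*C;
    unsigned char *in;
    float *out;
    cudaMallocManaged(&in, total);
    cudaMallocManaged(&out, total * sizeof(float));
    for (int i = 0; i < total; ++i) in[i] = i % 256;
    batched_hwc_to_chw<<<N, dim3(32, 32)>>>(in, out, H, W, C, 255.0f);
    cudaDeviceSynchronize();
    printf("image 1, plane 0, pixel 0 = %f\n", out[H*W*C]);
    cudaFree(in);
    cudaFree(out);
    return 0;
  }
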
From caa489f78f18969789bec2470df427b1a3d7e583 Mon Sep 17 00:00:00 2001
From: Sandeep Krishnamurthy
Date: Sat, 9 Feb 2019 20:44:41 -0800
Subject: [PATCH 04/10] Fix failing CPU tests for totensor

---
 src/operator/image/image_random-inl.h | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

diff --git a/src/operator/image/image_random-inl.h b/src/operator/image/image_random-inl.h
index c2214461308c..3401d319619e 100644
--- a/src/operator/image/image_random-inl.h
+++ b/src/operator/image/image_random-inl.h
@@ -93,7 +93,7 @@ inline void ToTensor(float* out_data, const DType* in_data,
                      const int length,
                      const int channels,
                      const float normalize_factor,
-                     const int step = 0) {
+                     const int step) {
   #pragma omp parallel for collapse(2)
   for (int c = 0; c < channels; ++c) {
     for (int i = 0; i < length; ++i) {
@@ -109,12 +109,13 @@ inline void ToTensorImpl(const std::vector<TBlob> &inputs,
                          const int length,
                          const int channel,
                          const float normalize_factor,
-                         const int step = 0) {
+                         const int step) {
   MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, DType, {
     MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, {
       float* output = outputs[0].dptr<float>();
       DType* input = inputs[0].dptr<DType>();
-      ToTensor<DType, req_type>(output, input, length, channel, step);
+      ToTensor<DType, req_type>(output, input, length, channel,
+                                normalize_factor, step);
     });
   });
 }
@@ -159,8 +160,9 @@ void ToTensorOpForward(const nnvm::NodeAttrs &attrs,
     // 3D Input - (h, w, c)
     const int length = inputs[0].shape_[0] * inputs[0].shape_[1];
     const int channel = static_cast<int>(inputs[0].shape_[2]);
+    const int step = 0;
     ToTensorImpl(inputs, outputs, req, length,
-                 channel, normalize_factor);
+                 channel, normalize_factor, step);
   } else if (inputs[0].ndim() == 4) {
     // 4D input (n, h, w, c)
     const int batch_size = inputs[0].shape_[0];
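
Note: every store in these loops goes through KERNEL_ASSIGN, which is why req_type is threaded into each call — the macro turns the requested output semantics into an overwrite or an accumulation. A simplified stand-in for the macro (the real one lives in src/operator/mxnet_op.h; this sketch assumes its usual kWriteTo/kAddTo behavior):

  #include <cassert>

  enum OpReqType { kNullOp, kWriteTo, kWriteInplace, kAddTo };

  // Simplified stand-in for mxnet's KERNEL_ASSIGN: store a computed value
  // according to the requested output semantics.
  #define KERNEL_ASSIGN(out, req, exp)  \
    {                                   \
      switch (req) {                    \
        case kNullOp:                   \
          break;                        \
        case kWriteTo:                  \
        case kWriteInplace:             \
          (out) = (exp);                \
          break;                        \
        case kAddTo:                    \
          (out) += (exp);               \
          break;                        \
      }                                 \
    }

  int main() {
    float y = 1.0f;
    KERNEL_ASSIGN(y, kWriteTo, 0.5f);  // overwrite: y == 0.5
    assert(y == 0.5f);
    KERNEL_ASSIGN(y, kAddTo, 0.5f);    // accumulate: y == 1.0
    assert(y == 1.0f);
    return 0;
  }

ToTensorOpForward itself only ever passes kWriteTo (see the CHECK_EQ earlier in the series), but the macro is shared across operators.
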
From d58f50c28efe42698163bfd3b82f48c8bcd20571 Mon Sep 17 00:00:00 2001
From: Sandeep Krishnamurthy
Date: Sat, 9 Feb 2019 22:59:18 -0800
Subject: [PATCH 05/10] disable warning on windows

---
 src/operator/image/image_random-inl.h | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/src/operator/image/image_random-inl.h b/src/operator/image/image_random-inl.h
index 3401d319619e..cc10cab95109 100644
--- a/src/operator/image/image_random-inl.h
+++ b/src/operator/image/image_random-inl.h
@@ -25,6 +25,9 @@
 #ifndef MXNET_OPERATOR_IMAGE_IMAGE_RANDOM_INL_H_
 #define MXNET_OPERATOR_IMAGE_IMAGE_RANDOM_INL_H_
 
+#ifdef _MSC_VER
+  #pragma warning(disable:4503)  // disable warning: decorated name length exceeded.
+#endif
 #include <mxnet/base.h>
 #include <algorithm>
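
Note: C4503 is MSVC's "decorated name length exceeded" diagnostic — when a mangled template name (such as the deeply nested mshadow Tensor expression types in this header) overflows the compiler's limit, the name is truncated and the warning fires. The suppression pattern, kept invisible to other compilers:

  // Guarded, compiler-specific suppression: only MSVC sees the pragma,
  // so GCC/Clang builds are untouched. C4503 is informational (the
  // mangled name is truncated; code generation is unaffected).
  #ifdef _MSC_VER
    #pragma warning(disable:4503)
  #endif

  int main() { return 0; }
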
From 41b7a8c44b4eca3e366241cb8ae4b14b3587915b Mon Sep 17 00:00:00 2001
From: Sandeep Krishnamurthy
Date: Sat, 9 Feb 2019 23:06:28 -0800
Subject: [PATCH 06/10] try fix in instance norm windows build failure

---
 src/operator/image/image_random-inl.h | 3 ---
 src/operator/instance_norm-inl.h      | 4 ++++
 2 files changed, 4 insertions(+), 3 deletions(-)

diff --git a/src/operator/image/image_random-inl.h b/src/operator/image/image_random-inl.h
index cc10cab95109..3401d319619e 100644
--- a/src/operator/image/image_random-inl.h
+++ b/src/operator/image/image_random-inl.h
@@ -25,9 +25,6 @@
 #ifndef MXNET_OPERATOR_IMAGE_IMAGE_RANDOM_INL_H_
 #define MXNET_OPERATOR_IMAGE_IMAGE_RANDOM_INL_H_
 
-#ifdef _MSC_VER
-  #pragma warning(disable:4503)  // disable warning: decorated name length exceeded.
-#endif
 #include <mxnet/base.h>
 #include <algorithm>

diff --git a/src/operator/instance_norm-inl.h b/src/operator/instance_norm-inl.h
index 258c164450d0..361d7fffe7ae 100644
--- a/src/operator/instance_norm-inl.h
+++ b/src/operator/instance_norm-inl.h
@@ -36,6 +36,10 @@
 #include "./operator_common.h"
 #include "./mshadow_op.h"
 
+#ifdef _MSC_VER
+  #pragma warning(disable:4503)  // disable warning: decorated name length exceeded.
+#endif
+
 namespace mxnet {
 namespace op {

From 5b8732ae4a0ce18dcb25166532d942ac5d41b114 Mon Sep 17 00:00:00 2001
From: Sandeep Krishnamurthy
Date: Sat, 9 Feb 2019 23:22:09 -0800
Subject: [PATCH 07/10] Guard omp parallel collapse for windows

---
 src/operator/image/image_random-inl.h | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/src/operator/image/image_random-inl.h b/src/operator/image/image_random-inl.h
index 3401d319619e..b6ccdb2617f3 100644
--- a/src/operator/image/image_random-inl.h
+++ b/src/operator/image/image_random-inl.h
@@ -94,7 +94,12 @@ inline void ToTensor(float* out_data, const DType* in_data,
                      const int channels,
                      const float normalize_factor,
                      const int step) {
-  #pragma omp parallel for collapse(2)
+  // Visual C++ compiler does not support omp collapse
+  #ifdef _MSC_VER
+    #pragma omp parallel for
+  #else
+    #pragma omp parallel for collapse(2)
+  #endif
   for (int c = 0; c < channels; ++c) {
     for (int i = 0; i < length; ++i) {
       KERNEL_ASSIGN(out_data[step + c*length + i], req,
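
Note: collapse(2) fuses the channel and pixel loops into one iteration space, so OpenMP parallelizes across channels * length iterations instead of only channels; MSVC implements OpenMP 2.0, which predates the clause, hence the guard. A hedged sketch of the portable alternative — flattening the loop nest by hand:

  #include <vector>

  // Manual equivalent of `#pragma omp parallel for collapse(2)`:
  // one parallel loop over channels*length, with (c, i) recovered
  // by division/modulo. Works under OpenMP 2.0 (i.e., MSVC) too.
  void to_tensor_flat(float* out, const unsigned char* in,
                      const int length, const int channels) {
    const int total = channels * length;
    #pragma omp parallel for
    for (int j = 0; j < total; ++j) {
      const int c = j / length;
      const int i = j % length;
      out[c*length + i] = in[i*channels + c] / 255.0f;
    }
  }

  int main() {
    std::vector<unsigned char> in(4 * 3, 128);
    std::vector<float> out(4 * 3);
    to_tensor_flat(out.data(), in.data(), 4, 3);
    return 0;
  }
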
"./operator_common.h" #include "./mshadow_op.h" -#ifdef _MSC_VER - #pragma warning(disable:4503) // disable warning: decorated name length exceeded. -#endif - namespace mxnet { namespace op { From fc630eb5e1ab2ba0398babe2fd501302c6fbee9e Mon Sep 17 00:00:00 2001 From: Sandeep Krishnamurthy Date: Sun, 10 Feb 2019 07:52:47 -0800 Subject: [PATCH 09/10] fix lint issues --- src/operator/image/image_random-inl.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/operator/image/image_random-inl.h b/src/operator/image/image_random-inl.h index 9cab7ef570ed..f17a10342cbe 100644 --- a/src/operator/image/image_random-inl.h +++ b/src/operator/image/image_random-inl.h @@ -99,7 +99,7 @@ inline void ToTensor(float* out_data, const DType* in_data, #pragma omp parallel for #else #pragma omp parallel for collapse(2) - #endif // _MSC_VER + #endif // _MSC_VER for (int c = 0; c < channels; ++c) { for (int i = 0; i < length; ++i) { KERNEL_ASSIGN(out_data[step + c*length + i], req, From 58c68011e26645cb31f59ad12f32fa1b5e6db8a9 Mon Sep 17 00:00:00 2001 From: Sandeep Krishnamurthy Date: Sun, 10 Feb 2019 22:23:20 -0800 Subject: [PATCH 10/10] Address code review comments --- src/operator/image/image_random-inl.h | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/operator/image/image_random-inl.h b/src/operator/image/image_random-inl.h index f17a10342cbe..392fff4dbf81 100644 --- a/src/operator/image/image_random-inl.h +++ b/src/operator/image/image_random-inl.h @@ -160,6 +160,8 @@ void ToTensorOpForward(const nnvm::NodeAttrs &attrs, } }); }); + #else + LOG(FATAL) << "Compile with USE_CUDA=1 to use ToTensor operator on GPU."; #endif // MXNET_USE_CUDA } else if (inputs[0].ndim() == 3) { // 3D Input - (h, w, c)