From f0bfbfa6d59c0362a6cc58749b1af4eeb4717da9 Mon Sep 17 00:00:00 2001 From: ziheng Date: Sun, 30 Apr 2017 17:34:08 -0700 Subject: [PATCH] Change to CUDNN_CALL (#6048) * Change to CUDNN_CALL * Fix lint --- include/mxnet/operator.h | 2 +- src/operator/cudnn_activation-inl.h | 104 ++-- src/operator/cudnn_batch_norm-inl.h | 163 +++---- src/operator/cudnn_bilinear_sampler-inl.h | 97 ++-- src/operator/cudnn_deconvolution-inl.h | 482 +++++++++---------- src/operator/cudnn_lrn-inl.h | 76 +-- src/operator/cudnn_pooling-inl.h | 170 ++++--- src/operator/cudnn_rnn-inl.h | 418 ++++++++-------- src/operator/cudnn_softmax_activation-inl.h | 58 +-- src/operator/cudnn_spatial_transformer-inl.h | 112 ++--- src/operator/operator_common.h | 1 + 11 files changed, 836 insertions(+), 847 deletions(-) diff --git a/include/mxnet/operator.h b/include/mxnet/operator.h index 41cc4b17b897..02bcdcb60c04 100644 --- a/include/mxnet/operator.h +++ b/include/mxnet/operator.h @@ -64,7 +64,7 @@ struct OpContext { /*! * \brief Operator interface. - * Operator defins basic operation unit of optimized computation graph in mxnet. + * Operator defines basic operation unit of optimized computation graph in mxnet. * This interface relies on pre-allocated memory in TBlob, the caller need to set * the memory region in TBlob correctly before calling Forward and Backward. * diff --git a/src/operator/cudnn_activation-inl.h b/src/operator/cudnn_activation-inl.h index 96488ef30e45..68f68b6225be 100644 --- a/src/operator/cudnn_activation-inl.h +++ b/src/operator/cudnn_activation-inl.h @@ -36,18 +36,16 @@ class CuDNNActivationOp : public Operator { } #if CUDNN_MAJOR >= 5 nan_prop_ = CUDNN_NOT_PROPAGATE_NAN; - CHECK_EQ(cudnnCreateActivationDescriptor(&desc_), - CUDNN_STATUS_SUCCESS); - CHECK_EQ(cudnnSetActivationDescriptor(desc_, mode_, nan_prop_, relu_ceil_), - CUDNN_STATUS_SUCCESS); + CUDNN_CALL(cudnnCreateActivationDescriptor(&desc_)); + CUDNN_CALL(cudnnSetActivationDescriptor(desc_, mode_, nan_prop_, relu_ceil_)); #endif } ~CuDNNActivationOp() { if (init_cudnn_) { - CHECK_EQ(cudnnDestroyTensorDescriptor(shape_desc_), CUDNN_STATUS_SUCCESS); + CUDNN_CALL(cudnnDestroyTensorDescriptor(shape_desc_)); #if CUDNN_MAJOR >= 5 - CHECK_EQ(cudnnDestroyActivationDescriptor(desc_), CUDNN_STATUS_SUCCESS); + CUDNN_CALL(cudnnDestroyActivationDescriptor(desc_)); #endif } } @@ -89,33 +87,33 @@ class CuDNNActivationOp : public Operator { CHECK_EQ(s->dnn_handle_ownership_, mshadow::Stream::OwnHandle); if (!init_cudnn_) { init_cudnn_ = true; - CHECK_EQ(cudnnCreateTensorDescriptor(&shape_desc_), CUDNN_STATUS_SUCCESS); - CHECK_EQ(cudnnSetTensor4dDescriptor(shape_desc_, - CUDNN_TENSOR_NCHW, - dtype_, - data.shape_[0], - data.shape_[1], - data.shape_[2], - data.shape_[3]), CUDNN_STATUS_SUCCESS); + CUDNN_CALL(cudnnCreateTensorDescriptor(&shape_desc_)); + CUDNN_CALL(cudnnSetTensor4dDescriptor(shape_desc_, + CUDNN_TENSOR_NCHW, + dtype_, + data.shape_[0], + data.shape_[1], + data.shape_[2], + data.shape_[3])); } #if CUDNN_MAJOR <= 4 - CHECK_EQ(cudnnActivationForward(s->dnn_handle_, - mode_, - &alpha, - shape_desc_, - data.dptr_, - &beta, - shape_desc_, - out.dptr_), CUDNN_STATUS_SUCCESS); + CUDNN_CALL(cudnnActivationForward(s->dnn_handle_, + mode_, + &alpha, + shape_desc_, + data.dptr_, + &beta, + shape_desc_, + out.dptr_)); #elif CUDNN_MAJOR >= 5 - CHECK_EQ(cudnnActivationForward(s->dnn_handle_, + CUDNN_CALL(cudnnActivationForward(s->dnn_handle_, desc_, - &alpha, - shape_desc_, - data.dptr_, - &beta, - shape_desc_, - out.dptr_), CUDNN_STATUS_SUCCESS); + &alpha, + 
shape_desc_, + data.dptr_, + &beta, + shape_desc_, + out.dptr_)); #endif } @@ -166,31 +164,31 @@ class CuDNNActivationOp : public Operator { } CHECK_EQ(s->dnn_handle_ownership_, mshadow::Stream::OwnHandle); #if CUDNN_MAJOR <= 4 - CHECK_EQ(cudnnActivationBackward(s->dnn_handle_, - mode_, - &alpha, - shape_desc_, - output_data.dptr_, - shape_desc_, - grad.dptr_, - shape_desc_, - data.dptr_, - &beta, - shape_desc_, - input_grad.dptr_), CUDNN_STATUS_SUCCESS); + CUDNN_CALL(cudnnActivationBackward(s->dnn_handle_, + mode_, + &alpha, + shape_desc_, + output_data.dptr_, + shape_desc_, + grad.dptr_, + shape_desc_, + data.dptr_, + &beta, + shape_desc_, + input_grad.dptr_)); #elif CUDNN_MAJOR >= 5 - CHECK_EQ(cudnnActivationBackward(s->dnn_handle_, - desc_, - &alpha, - shape_desc_, - output_data.dptr_, - shape_desc_, - grad.dptr_, - shape_desc_, - data.dptr_, - &beta, - shape_desc_, - input_grad.dptr_), CUDNN_STATUS_SUCCESS); + CUDNN_CALL(cudnnActivationBackward(s->dnn_handle_, + desc_, + &alpha, + shape_desc_, + output_data.dptr_, + shape_desc_, + grad.dptr_, + shape_desc_, + data.dptr_, + &beta, + shape_desc_, + input_grad.dptr_)); #endif } diff --git a/src/operator/cudnn_batch_norm-inl.h b/src/operator/cudnn_batch_norm-inl.h index 596944bcaf08..b917d95a6bfe 100755 --- a/src/operator/cudnn_batch_norm-inl.h +++ b/src/operator/cudnn_batch_norm-inl.h @@ -40,8 +40,8 @@ class CuDNNBatchNormOp : public Operator { ~CuDNNBatchNormOp() { if (init_cudnn_) { - CHECK_EQ(cudnnDestroyTensorDescriptor(io_desc_), CUDNN_STATUS_SUCCESS); - CHECK_EQ(cudnnDestroyTensorDescriptor(mean_desc_), CUDNN_STATUS_SUCCESS); + CUDNN_CALL(cudnnDestroyTensorDescriptor(io_desc_)); + CUDNN_CALL(cudnnDestroyTensorDescriptor(mean_desc_)); } } @@ -73,18 +73,18 @@ class CuDNNBatchNormOp : public Operator { shape_[i] = 1; } } - CHECK_EQ(cudnnCreateTensorDescriptor(&io_desc_), CUDNN_STATUS_SUCCESS); - CHECK_EQ(cudnnCreateTensorDescriptor(&mean_desc_), CUDNN_STATUS_SUCCESS); - CHECK_EQ(cudnnSetTensor4dDescriptor(io_desc_, - CUDNN_TENSOR_NCHW, - dtype_, - shape_[0], - shape_[1], - shape_[2], - shape_[3]), CUDNN_STATUS_SUCCESS); - CHECK_EQ(cudnnDeriveBNTensorDescriptor(mean_desc_, - io_desc_, - CUDNN_BATCHNORM_SPATIAL), CUDNN_STATUS_SUCCESS); + CUDNN_CALL(cudnnCreateTensorDescriptor(&io_desc_)); + CUDNN_CALL(cudnnCreateTensorDescriptor(&mean_desc_)); + CUDNN_CALL(cudnnSetTensor4dDescriptor(io_desc_, + CUDNN_TENSOR_NCHW, + dtype_, + shape_[0], + shape_[1], + shape_[2], + shape_[3])); + CUDNN_CALL(cudnnDeriveBNTensorDescriptor(mean_desc_, + io_desc_, + CUDNN_BATCHNORM_SPATIAL)); init_cudnn_ = true; } @@ -117,38 +117,38 @@ class CuDNNBatchNormOp : public Operator { Tensor save_inv_var = out_data[cudnnbatchnorm::kInvVar] .get_with_shape(Shape1(shape_[1]), s); - CHECK_EQ(cudnnBatchNormalizationForwardTraining(s->dnn_handle_, - CUDNN_BATCHNORM_SPATIAL, - &a, - &b, - io_desc_, - x.dptr_, - io_desc_, - y.dptr_, - mean_desc_, - gamma.dptr_, - beta.dptr_, - 1 - param_.momentum, - moving_mean.dptr_, - moving_inv_var.dptr_, - param_.eps, - save_mean.dptr_, - save_inv_var.dptr_), CUDNN_STATUS_SUCCESS); + CUDNN_CALL(cudnnBatchNormalizationForwardTraining(s->dnn_handle_, + CUDNN_BATCHNORM_SPATIAL, + &a, + &b, + io_desc_, + x.dptr_, + io_desc_, + y.dptr_, + mean_desc_, + gamma.dptr_, + beta.dptr_, + 1 - param_.momentum, + moving_mean.dptr_, + moving_inv_var.dptr_, + param_.eps, + save_mean.dptr_, + save_inv_var.dptr_)); } else { - CHECK_EQ(cudnnBatchNormalizationForwardInference(s->dnn_handle_, - CUDNN_BATCHNORM_SPATIAL, - &a, - &b, - io_desc_, - x.dptr_, - 
io_desc_, - y.dptr_, - mean_desc_, - gamma.dptr_, - beta.dptr_, - moving_mean.dptr_, - moving_inv_var.dptr_, - param_.eps), CUDNN_STATUS_SUCCESS); + CUDNN_CALL(cudnnBatchNormalizationForwardInference(s->dnn_handle_, + CUDNN_BATCHNORM_SPATIAL, + &a, + &b, + io_desc_, + x.dptr_, + io_desc_, + y.dptr_, + mean_desc_, + gamma.dptr_, + beta.dptr_, + moving_mean.dptr_, + moving_inv_var.dptr_, + param_.eps)); } }) } @@ -197,25 +197,26 @@ class CuDNNBatchNormOp : public Operator { if (param_.fix_gamma) gamma = 1.f; - CHECK_EQ(cudnnBatchNormalizationBackward(s->dnn_handle_, - CUDNN_BATCHNORM_SPATIAL, - &a, - &b, - &a, - req[cudnnbatchnorm::kGamma] == kWriteTo ? &b: &b_add, - io_desc_, - x.dptr_, - io_desc_, - dy.dptr_, - io_desc_, - dx.dptr_, - mean_desc_, - gamma.dptr_, - dgamma.dptr_, - dbeta.dptr_, - param_.eps, - save_mean.dptr_, - save_inv_var.dptr_), CUDNN_STATUS_SUCCESS); + CUDNN_CALL(cudnnBatchNormalizationBackward( + s->dnn_handle_, + CUDNN_BATCHNORM_SPATIAL, + &a, + &b, + &a, + req[cudnnbatchnorm::kGamma] == kWriteTo ? &b: &b_add, + io_desc_, + x.dptr_, + io_desc_, + dy.dptr_, + io_desc_, + dx.dptr_, + mean_desc_, + gamma.dptr_, + dgamma.dptr_, + dbeta.dptr_, + param_.eps, + save_mean.dptr_, + save_inv_var.dptr_)); if (param_.fix_gamma) dgamma = 0.f; }) #else // CUDNN_VERSION < 4007 @@ -237,23 +238,23 @@ class CuDNNBatchNormOp : public Operator { CHECK_EQ(s->dnn_handle_ownership_, mshadow::Stream::OwnHandle); if (param_.fix_gamma) gamma = 1.f; - CHECK_EQ(cudnnBatchNormalizationBackward(s->dnn_handle_, - CUDNN_BATCHNORM_SPATIAL, - &a, - &b, - io_desc_, - x.dptr_, - io_desc_, - dy.dptr_, - io_desc_, - dx.dptr_, - mean_desc_, - gamma.dptr_, - dgamma.dptr_, - dbeta.dptr_, - param_.eps, - save_mean.dptr_, - save_inv_var.dptr_), CUDNN_STATUS_SUCCESS); + CUDNN_CALL(cudnnBatchNormalizationBackward(s->dnn_handle_, + CUDNN_BATCHNORM_SPATIAL, + &a, + &b, + io_desc_, + x.dptr_, + io_desc_, + dy.dptr_, + io_desc_, + dx.dptr_, + mean_desc_, + gamma.dptr_, + dgamma.dptr_, + dbeta.dptr_, + param_.eps, + save_mean.dptr_, + save_inv_var.dptr_)); if (param_.fix_gamma) dgamma = 0.f; }) #endif diff --git a/src/operator/cudnn_bilinear_sampler-inl.h b/src/operator/cudnn_bilinear_sampler-inl.h index e5cd81dbe256..8b012b71723b 100644 --- a/src/operator/cudnn_bilinear_sampler-inl.h +++ b/src/operator/cudnn_bilinear_sampler-inl.h @@ -25,9 +25,9 @@ class CuDNNBilinearSamplerOp : public Operator { ~CuDNNBilinearSamplerOp() { if (init_cudnn_) { - CHECK_EQ(cudnnDestroySpatialTransformerDescriptor(st_desc_), CUDNN_STATUS_SUCCESS); - CHECK_EQ(cudnnDestroyTensorDescriptor(in_desc_), CUDNN_STATUS_SUCCESS); - CHECK_EQ(cudnnDestroyTensorDescriptor(out_desc_), CUDNN_STATUS_SUCCESS); + CUDNN_CALL(cudnnDestroySpatialTransformerDescriptor(st_desc_)); + CUDNN_CALL(cudnnDestroyTensorDescriptor(in_desc_)); + CUDNN_CALL(cudnnDestroyTensorDescriptor(out_desc_)); } } @@ -56,15 +56,15 @@ class CuDNNBilinearSamplerOp : public Operator { CHECK_EQ(grid_tmp.CheckContiguous(), true); typename DataType::ScaleType alpha = 1.0f; typename DataType::ScaleType beta = 0.0f; - CHECK_EQ(cudnnSpatialTfSamplerForward(s->dnn_handle_, - st_desc_, - &alpha, - in_desc_, - data.dptr_, - grid_tmp.dptr_, - &beta, - out_desc_, - out.dptr_/*output*/), CUDNN_STATUS_SUCCESS); + CUDNN_CALL(cudnnSpatialTfSamplerForward(s->dnn_handle_, + st_desc_, + &alpha, + in_desc_, + data.dptr_, + grid_tmp.dptr_, + &beta, + out_desc_, + out.dptr_)); } virtual void Backward(const OpContext &ctx, @@ -91,21 +91,20 @@ class CuDNNBilinearSamplerOp : public Operator { typename 
DataType::ScaleType beta = (req[bs::kData] == kAddTo) ? 1.0f : 0.0f; typename DataType::ScaleType alpha_dgrid = 1.0f; typename DataType::ScaleType beta_dgrid = 0.0f; - CHECK_EQ(cudnnSpatialTfSamplerBackward(s->dnn_handle_, - st_desc_, - &alpha, - in_desc_, - data.dptr_, - &beta, - in_desc_/*reuse in_desc_*/, - gdata.dptr_/*output*/, - &alpha_dgrid, - out_desc_/*reuse out_desc_*/, - grad.dptr_, - grid_tmp.dptr_, - &beta_dgrid, - grid_tmp.dptr_/*output, reuse grid*/), - CUDNN_STATUS_SUCCESS); + CUDNN_CALL(cudnnSpatialTfSamplerBackward(s->dnn_handle_, + st_desc_, + &alpha, + in_desc_, + data.dptr_, + &beta, + in_desc_/*reuse in_desc_*/, + gdata.dptr_/*output*/, + &alpha_dgrid, + out_desc_/*reuse out_desc_*/, + grad.dptr_, + grid_tmp.dptr_, + &beta_dgrid, + grid_tmp.dptr_)); Assign(ggrid, req[bs::kGrid], transpose(grid_tmp, Shape4(0, 3, 1, 2))); } @@ -123,30 +122,30 @@ class CuDNNBilinearSamplerOp : public Operator { init_cudnn_ = true; Tensor data = in_data[bs::kData].get(s); Tensor out = out_data[bs::kOut].get(s); - CHECK_EQ(cudnnCreateSpatialTransformerDescriptor(&st_desc_), CUDNN_STATUS_SUCCESS); - CHECK_EQ(cudnnCreateTensorDescriptor(&in_desc_), CUDNN_STATUS_SUCCESS); - CHECK_EQ(cudnnCreateTensorDescriptor(&out_desc_), CUDNN_STATUS_SUCCESS); - CHECK_EQ(cudnnSetTensor4dDescriptor(in_desc_, - format_, - dtype_, - data.size(0), - data.size(1), - data.size(2), - data.size(3)), CUDNN_STATUS_SUCCESS); - CHECK_EQ(cudnnSetTensor4dDescriptor(out_desc_, - format_, - dtype_, - out.size(0), - out.size(1), - out.size(2), - out.size(3)), CUDNN_STATUS_SUCCESS); + CUDNN_CALL(cudnnCreateSpatialTransformerDescriptor(&st_desc_)); + CUDNN_CALL(cudnnCreateTensorDescriptor(&in_desc_)); + CUDNN_CALL(cudnnCreateTensorDescriptor(&out_desc_)); + CUDNN_CALL(cudnnSetTensor4dDescriptor(in_desc_, + format_, + dtype_, + data.size(0), + data.size(1), + data.size(2), + data.size(3))); + CUDNN_CALL(cudnnSetTensor4dDescriptor(out_desc_, + format_, + dtype_, + out.size(0), + out.size(1), + out.size(2), + out.size(3))); int dim[] = {static_cast(out.size(0)), static_cast(out.size(1)), static_cast(out.size(2)), static_cast(out.size(3))}; - CHECK_EQ(cudnnSetSpatialTransformerNdDescriptor(st_desc_, - sampler_, - dtype_, - 4, - dim) , CUDNN_STATUS_SUCCESS); + CUDNN_CALL(cudnnSetSpatialTransformerNdDescriptor(st_desc_, + sampler_, + dtype_, + 4, + dim)); } } diff --git a/src/operator/cudnn_deconvolution-inl.h b/src/operator/cudnn_deconvolution-inl.h index 893b2b49c9d3..8405c2399897 100644 --- a/src/operator/cudnn_deconvolution-inl.h +++ b/src/operator/cudnn_deconvolution-inl.h @@ -68,12 +68,12 @@ class CuDNNDeconvolutionOp : public Operator { ~CuDNNDeconvolutionOp() { if (init_cudnn_) { - CHECK_EQ(cudnnDestroyTensorDescriptor(in_desc_), CUDNN_STATUS_SUCCESS); - CHECK_EQ(cudnnDestroyTensorDescriptor(out_desc_), CUDNN_STATUS_SUCCESS); - CHECK_EQ(cudnnDestroyTensorDescriptor(bias_desc_), CUDNN_STATUS_SUCCESS); - CHECK_EQ(cudnnDestroyFilterDescriptor(filter_desc_), CUDNN_STATUS_SUCCESS); - CHECK_EQ(cudnnDestroyConvolutionDescriptor(forward_conv_desc_), CUDNN_STATUS_SUCCESS); - CHECK_EQ(cudnnDestroyConvolutionDescriptor(backward_conv_desc_), CUDNN_STATUS_SUCCESS); + CUDNN_CALL(cudnnDestroyTensorDescriptor(in_desc_)); + CUDNN_CALL(cudnnDestroyTensorDescriptor(out_desc_)); + CUDNN_CALL(cudnnDestroyTensorDescriptor(bias_desc_)); + CUDNN_CALL(cudnnDestroyFilterDescriptor(filter_desc_)); + CUDNN_CALL(cudnnDestroyConvolutionDescriptor(forward_conv_desc_)); + CUDNN_CALL(cudnnDestroyConvolutionDescriptor(backward_conv_desc_)); } } @@ -121,55 
+121,55 @@ class CuDNNDeconvolutionOp : public Operator { typename DataType::ScaleType alpha = 1.0f; typename DataType::ScaleType beta = 0.0f; #if CUDNN_MAJOR <= 4 - CHECK_EQ(cudnnConvolutionBackwardData_v3(s->dnn_handle_, - &alpha, - filter_desc_, - wmat_ptr + weight_offset_ * g, - in_desc_, - data_ptr + data_offset_ * g, - forward_conv_desc_, // this backward algorithm used for inference - back_algo_, - workspace.dptr_, - backward_workspace_byte_, - &beta, - out_desc_, - out.dptr_ + out_offset_ * g), CUDNN_STATUS_SUCCESS); + CUDNN_CALL(cudnnConvolutionBackwardData_v3(s->dnn_handle_, + &alpha, + filter_desc_, + wmat_ptr + weight_offset_ * g, + in_desc_, + data_ptr + data_offset_ * g, + forward_conv_desc_, // this backward algorithm used for inference + back_algo_, + workspace.dptr_, + backward_workspace_byte_, + &beta, + out_desc_, + out.dptr_ + out_offset_ * g)); #elif CUDNN_MAJOR >= 5 - CHECK_EQ(cudnnConvolutionBackwardData(s->dnn_handle_, - &alpha, - filter_desc_, - wmat_ptr + weight_offset_ * g, - in_desc_, - data_ptr + data_offset_ * g, - forward_conv_desc_, // this backward algorithm used for inference - back_algo_, - workspace.dptr_, - backward_workspace_byte_, - &beta, - out_desc_, - out_ptr + out_offset_ * g), CUDNN_STATUS_SUCCESS); + CUDNN_CALL(cudnnConvolutionBackwardData(s->dnn_handle_, + &alpha, + filter_desc_, + wmat_ptr + weight_offset_ * g, + in_desc_, + data_ptr + data_offset_ * g, + forward_conv_desc_, // this backward algorithm used for inference + back_algo_, + workspace.dptr_, + backward_workspace_byte_, + &beta, + out_desc_, + out_ptr + out_offset_ * g)); #endif if (!param_.no_bias) { beta = 1.0f; Tensor bias = in_data[deconv::kBias].get(s); #if CUDNN_MAJOR >= 4 - CHECK_EQ(cudnnAddTensor(s->dnn_handle_, - &alpha, - bias_desc_, - bias.dptr_ + bias_offset_ * g, - &beta, - out_desc_, - out_ptr + out_offset_ * g), CUDNN_STATUS_SUCCESS); + CUDNN_CALL(cudnnAddTensor(s->dnn_handle_, + &alpha, + bias_desc_, + bias.dptr_ + bias_offset_ * g, + &beta, + out_desc_, + out_ptr + out_offset_ * g)); #endif #if CUDNN_MAJOR == 3 - CHECK_EQ(cudnnAddTensor(s->dnn_handle_, - CUDNN_ADD_SAME_C, - &alpha, - bias_desc_, - bias.dptr_ + bias_offset_ * g, - &beta, - out_desc_, - out_ptr + out_offset_ * g), CUDNN_STATUS_SUCCESS); + CUDNN_CALL(cudnnAddTensor(s->dnn_handle_, + CUDNN_ADD_SAME_C, + &alpha, + bias_desc_, + bias.dptr_ + bias_offset_ * g, + &beta, + out_desc_, + out_ptr + out_offset_ * g)); #endif } } @@ -232,60 +232,61 @@ class CuDNNDeconvolutionOp : public Operator { req[deconv::kWeight] == kAddTo ? 
1.0f : 0.0f; if (!param_.no_bias && (req[deconv::kBias] != kNullOp)) { Tensor gbias = in_grad[deconv::kBias].get(s); - CHECK_EQ(cudnnConvolutionBackwardBias(s->dnn_handle_, - &alpha, - out_desc_, - grad_ptr + out_offset_ * g, - &bias_beta, - bias_desc_, - gbias.dptr_ + bias_offset_ * g), - CUDNN_STATUS_SUCCESS); + CUDNN_CALL(cudnnConvolutionBackwardBias(s->dnn_handle_, + &alpha, + out_desc_, + grad_ptr + out_offset_ * g, + &bias_beta, + bias_desc_, + gbias.dptr_ + bias_offset_ * g)); } if (req[deconv::kWeight] != kNullOp) { #if CUDNN_MAJOR <= 4 - CHECK_EQ(cudnnConvolutionBackwardFilter_v3(s->dnn_handle_, - &alpha, - out_desc_, - grad_ptr + out_offset_ * g, - in_desc_, - data_ptr + data_offset_ * g, - backward_conv_desc_, - back_algo_w_, - workspace.dptr_, - backward_workspace_byte_, - &weight_beta, - filter_desc_, - gwmat.dptr_ + weight_offset_ * g), CUDNN_STATUS_SUCCESS); + CUDNN_CALL(cudnnConvolutionBackwardFilter_v3( + s->dnn_handle_, + &alpha, + out_desc_, + grad_ptr + out_offset_ * g, + in_desc_, + data_ptr + data_offset_ * g, + backward_conv_desc_, + back_algo_w_, + workspace.dptr_, + backward_workspace_byte_, + &weight_beta, + filter_desc_, + gwmat.dptr_ + weight_offset_ * g)); #elif CUDNN_MAJOR >= 5 - CHECK_EQ(cudnnConvolutionBackwardFilter(s->dnn_handle_, - &alpha, - out_desc_, - grad_ptr + out_offset_ * g, - in_desc_, - data_ptr + data_offset_ * g, - backward_conv_desc_, - back_algo_w_, - workspace.dptr_, - backward_workspace_byte_, - &weight_beta, - filter_desc_, - gwmat_ptr + weight_offset_ * g), CUDNN_STATUS_SUCCESS); + CUDNN_CALL(cudnnConvolutionBackwardFilter( + s->dnn_handle_, + &alpha, + out_desc_, + grad_ptr + out_offset_ * g, + in_desc_, + data_ptr + data_offset_ * g, + backward_conv_desc_, + back_algo_w_, + workspace.dptr_, + backward_workspace_byte_, + &weight_beta, + filter_desc_, + gwmat_ptr + weight_offset_ * g)); #endif } if (req[deconv::kData] != kNullOp) { - CHECK_EQ(cudnnConvolutionForward(s->dnn_handle_, - &alpha, - out_desc_, - grad_ptr + out_offset_ * g, - filter_desc_, - wmat_ptr + weight_offset_ * g, - backward_conv_desc_, // fwd alg used to backprop-to-data - algo_, - workspace.dptr_, - forward_workspace_byte_, - &data_beta, - in_desc_, - gdata_ptr + data_offset_ * g), CUDNN_STATUS_SUCCESS); + CUDNN_CALL(cudnnConvolutionForward(s->dnn_handle_, + &alpha, + out_desc_, + grad_ptr + out_offset_ * g, + filter_desc_, + wmat_ptr + weight_offset_ * g, + backward_conv_desc_, + algo_, + workspace.dptr_, + forward_workspace_byte_, + &data_beta, + in_desc_, + gdata_ptr + data_offset_ * g)); } } } @@ -347,12 +348,12 @@ class CuDNNDeconvolutionOp : public Operator { size_t expected = param_.no_bias ? 
2 : 3; CHECK_EQ(in_shape.size(), expected); CHECK_EQ(out_shape.size(), 1U); - CHECK_EQ(cudnnCreateTensorDescriptor(&in_desc_), CUDNN_STATUS_SUCCESS); - CHECK_EQ(cudnnCreateTensorDescriptor(&out_desc_), CUDNN_STATUS_SUCCESS); - CHECK_EQ(cudnnCreateTensorDescriptor(&bias_desc_), CUDNN_STATUS_SUCCESS); - CHECK_EQ(cudnnCreateFilterDescriptor(&filter_desc_), CUDNN_STATUS_SUCCESS); - CHECK_EQ(cudnnCreateConvolutionDescriptor(&forward_conv_desc_), CUDNN_STATUS_SUCCESS); - CHECK_EQ(cudnnCreateConvolutionDescriptor(&backward_conv_desc_), CUDNN_STATUS_SUCCESS); + CUDNN_CALL(cudnnCreateTensorDescriptor(&in_desc_)); + CUDNN_CALL(cudnnCreateTensorDescriptor(&out_desc_)); + CUDNN_CALL(cudnnCreateTensorDescriptor(&bias_desc_)); + CUDNN_CALL(cudnnCreateFilterDescriptor(&filter_desc_)); + CUDNN_CALL(cudnnCreateConvolutionDescriptor(&forward_conv_desc_)); + CUDNN_CALL(cudnnCreateConvolutionDescriptor(&backward_conv_desc_)); TShape dshape = in_shape[deconv::kData]; TShape wshape = in_shape[deconv::kWeight]; @@ -367,64 +368,60 @@ class CuDNNDeconvolutionOp : public Operator { param_.InferPad(dshape, o_pad, o_adj); #if CUDNN_MAJOR >= 6 - CHECK_EQ(cudnnSetConvolution2dDescriptor(forward_conv_desc_, - o_pad[0], - o_pad[1], - param_.stride[0], - param_.stride[1], - param_.dilate[0], - param_.dilate[1], - CUDNN_CROSS_CORRELATION, - cudnn_forward_compute_type), - CUDNN_STATUS_SUCCESS); - CHECK_EQ(cudnnSetConvolution2dDescriptor(backward_conv_desc_, - o_pad[0], - o_pad[1], - param_.stride[0], - param_.stride[1], - param_.dilate[0], - param_.dilate[1], - CUDNN_CROSS_CORRELATION, - cudnn_backward_compute_type), - CUDNN_STATUS_SUCCESS); + CUDNN_CALL(cudnnSetConvolution2dDescriptor(forward_conv_desc_, + o_pad[0], + o_pad[1], + param_.stride[0], + param_.stride[1], + param_.dilate[0], + param_.dilate[1], + CUDNN_CROSS_CORRELATION, + cudnn_forward_compute_type)); + CUDNN_CALL(cudnnSetConvolution2dDescriptor(backward_conv_desc_, + o_pad[0], + o_pad[1], + param_.stride[0], + param_.stride[1], + param_.dilate[0], + param_.dilate[1], + CUDNN_CROSS_CORRELATION, + cudnn_backward_compute_type)); #else - CHECK_EQ(cudnnSetConvolution2dDescriptor(forward_conv_desc_, - o_pad[0], - o_pad[1], - param_.stride[0], - param_.stride[1], - param_.dilate[0], - param_.dilate[1], - CUDNN_CROSS_CORRELATION), - CUDNN_STATUS_SUCCESS); - CHECK_EQ(cudnnSetConvolution2dDescriptor(backward_conv_desc_, - o_pad[0], - o_pad[1], - param_.stride[0], - param_.stride[1], - param_.dilate[0], - param_.dilate[1], - CUDNN_CROSS_CORRELATION), - CUDNN_STATUS_SUCCESS); + CUDNN_CALL(cudnnSetConvolution2dDescriptor(forward_conv_desc_, + o_pad[0], + o_pad[1], + param_.stride[0], + param_.stride[1], + param_.dilate[0], + param_.dilate[1], + CUDNN_CROSS_CORRELATION)); + CUDNN_CALL(cudnnSetConvolution2dDescriptor(backward_conv_desc_, + o_pad[0], + o_pad[1], + param_.stride[0], + param_.stride[1], + param_.dilate[0], + param_.dilate[1], + CUDNN_CROSS_CORRELATION)); #endif #if CUDNN_MAJOR >= 5 wshape = ConvertLayout(wshape.get<4>(), param_.layout.value(), kNCHW); - CHECK_EQ(cudnnSetFilter4dDescriptor(filter_desc_, - dtype_, - format_, - wshape[0], - wshape[1], - wshape[2], - wshape[3]), CUDNN_STATUS_SUCCESS); + CUDNN_CALL(cudnnSetFilter4dDescriptor(filter_desc_, + dtype_, + format_, + wshape[0], + wshape[1], + wshape[2], + wshape[3])); #else CHECK_EQ(param_.layout.value(), kNCHW) << "CuDNN V4 only support NCHW layout"; - CHECK_EQ(cudnnSetFilter4dDescriptor(filter_desc_, - dtype_, - wshape[0], - wshape[1], - wshape[2], - wshape[3]), CUDNN_STATUS_SUCCESS); + 
CUDNN_CALL(cudnnSetFilter4dDescriptor(filter_desc_, + dtype_, + wshape[0], + wshape[1], + wshape[2], + wshape[3])); #endif dstride = ConvertLayout(Shape4(dshape[1] * dshape[2] * dshape[3], @@ -448,32 +445,29 @@ class CuDNNDeconvolutionOp : public Operator { #if CUDNN_MAJOR >= 5 CHECK_EQ(param_.layout.value(), kNCDHW) << "CuDNN only support 3D conv with NCDHW layout"; - CHECK_EQ(cudnnSetFilterNdDescriptor(filter_desc_, - dtype_, - CUDNN_TENSOR_NCHW, - static_cast(wshape.ndim()), - reinterpret_cast(&wshape[0])), - CUDNN_STATUS_SUCCESS); + CUDNN_CALL(cudnnSetFilterNdDescriptor(filter_desc_, + dtype_, + CUDNN_TENSOR_NCHW, + static_cast(wshape.ndim()), + reinterpret_cast(&wshape[0]))); #else LOG(FATAL) << "Only support CUDNN V5 for 3D convolution"; #endif - CHECK_EQ(cudnnSetConvolutionNdDescriptor(forward_conv_desc_, - 3, - reinterpret_cast(&o_pad[0]), - reinterpret_cast(¶m_.stride[0]), - reinterpret_cast(¶m_.dilate[0]), - CUDNN_CROSS_CORRELATION, - cudnn_forward_compute_type), - CUDNN_STATUS_SUCCESS); - - CHECK_EQ(cudnnSetConvolutionNdDescriptor(backward_conv_desc_, - 3, - reinterpret_cast(&o_pad[0]), - reinterpret_cast(¶m_.stride[0]), - reinterpret_cast(¶m_.dilate[0]), - CUDNN_CROSS_CORRELATION, - cudnn_backward_compute_type), - CUDNN_STATUS_SUCCESS); + CUDNN_CALL(cudnnSetConvolutionNdDescriptor(forward_conv_desc_, + 3, + reinterpret_cast(&o_pad[0]), + reinterpret_cast(¶m_.stride[0]), + reinterpret_cast(¶m_.dilate[0]), + CUDNN_CROSS_CORRELATION, + cudnn_forward_compute_type)); + + CUDNN_CALL(cudnnSetConvolutionNdDescriptor(backward_conv_desc_, + 3, + reinterpret_cast(&o_pad[0]), + reinterpret_cast(¶m_.stride[0]), + reinterpret_cast(¶m_.dilate[0]), + CUDNN_CROSS_CORRELATION, + cudnn_backward_compute_type)); dstride = ConvertLayout(Shape5(dshape[1] * dshape[2] * dshape[3] * dshape[4], dshape[2] * dshape[3] * dshape[4], @@ -497,19 +491,17 @@ class CuDNNDeconvolutionOp : public Operator { data_offset_ = dstride[1] * dshape[1]; out_offset_ = ostride[1] * oshape[1]; - CHECK_EQ(cudnnSetTensorNdDescriptor(in_desc_, - dtype_, - static_cast(dshape.ndim()), - reinterpret_cast(&dshape[0]), - reinterpret_cast(&dstride[0])), - CUDNN_STATUS_SUCCESS); + CUDNN_CALL(cudnnSetTensorNdDescriptor(in_desc_, + dtype_, + static_cast(dshape.ndim()), + reinterpret_cast(&dshape[0]), + reinterpret_cast(&dstride[0]))); - CHECK_EQ(cudnnSetTensorNdDescriptor(out_desc_, - dtype_, - static_cast(oshape.ndim()), - reinterpret_cast(&oshape[0]), - reinterpret_cast(&ostride[0])), - CUDNN_STATUS_SUCCESS); + CUDNN_CALL(cudnnSetTensorNdDescriptor(out_desc_, + dtype_, + static_cast(oshape.ndim()), + reinterpret_cast(&oshape[0]), + reinterpret_cast(&ostride[0]))); if (!param_.no_bias) { TShape bias = in_shape[deconv::kBias]; @@ -522,11 +514,11 @@ class CuDNNDeconvolutionOp : public Operator { bias_shape.push_back(1); bias_stride.push_back(1); } - CHECK_EQ(cudnnSetTensorNdDescriptor(bias_desc_, - dtype_, - static_cast(bias_shape.size()), - &bias_shape[0], - &bias_stride[0]), CUDNN_STATUS_SUCCESS); + CUDNN_CALL(cudnnSetTensorNdDescriptor(bias_desc_, + dtype_, + static_cast(bias_shape.size()), + &bias_shape[0], + &bias_stride[0])); } init_cudnn_ = true; } @@ -553,31 +545,31 @@ class CuDNNDeconvolutionOp : public Operator { if (CUDNN_MAJOR == 6 && param_.layout.value() == mshadow::kNHWC) { algo_ = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM; } else { - CHECK_EQ(cudnnGetConvolutionForwardAlgorithm(s->dnn_handle_, - out_desc_, - filter_desc_, - backward_conv_desc_, // forward algorithm used to backprop-to-data - in_desc_, - 
CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT, - workspace_byte, - &(this->algo_)), CUDNN_STATUS_SUCCESS); + CUDNN_CALL(cudnnGetConvolutionForwardAlgorithm(s->dnn_handle_, + out_desc_, + filter_desc_, + backward_conv_desc_, // forward algorithm used to backprop-to-data + in_desc_, + CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT, + workspace_byte, + &(this->algo_))); } - CHECK_EQ(cudnnGetConvolutionBackwardFilterAlgorithm(s->dnn_handle_, - out_desc_, - in_desc_, - backward_conv_desc_, - filter_desc_, - CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT, - workspace_byte, - &(this->back_algo_w_)), CUDNN_STATUS_SUCCESS); - CHECK_EQ(cudnnGetConvolutionBackwardDataAlgorithm(s->dnn_handle_, - filter_desc_, - in_desc_, - forward_conv_desc_, // this backward algorithm used for inference - out_desc_, - CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT, - workspace_byte, - &(this->back_algo_)), CUDNN_STATUS_SUCCESS); + CUDNN_CALL(cudnnGetConvolutionBackwardFilterAlgorithm(s->dnn_handle_, + out_desc_, + in_desc_, + backward_conv_desc_, + filter_desc_, + CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT, + workspace_byte, + &(this->back_algo_w_))); + CUDNN_CALL(cudnnGetConvolutionBackwardDataAlgorithm(s->dnn_handle_, + filter_desc_, + in_desc_, + forward_conv_desc_, // this backward algorithm used for inference + out_desc_, + CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT, + workspace_byte, + &(this->back_algo_))); } else { const int kMaxAlgos = 10; int nalgo = kMaxAlgos; @@ -589,14 +581,14 @@ class CuDNNDeconvolutionOp : public Operator { algo_ = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM; } else { cudnnConvolutionFwdAlgoPerf_t fwd_algo[kMaxAlgos]; - CHECK_EQ(cudnnFindConvolutionForwardAlgorithm(s->dnn_handle_, - out_desc_, - filter_desc_, - backward_conv_desc_, // forward algorithm used to backprop-to-data - in_desc_, - kMaxAlgos, - &nalgo, - fwd_algo), CUDNN_STATUS_SUCCESS); + CUDNN_CALL(cudnnFindConvolutionForwardAlgorithm(s->dnn_handle_, + out_desc_, + filter_desc_, + backward_conv_desc_, // forward algorithm used to backprop-to-data + in_desc_, + kMaxAlgos, + &nalgo, + fwd_algo)); i = 0; while (i < nalgo && (fwd_algo[i].status != CUDNN_STATUS_SUCCESS @@ -611,14 +603,14 @@ class CuDNNDeconvolutionOp : public Operator { } cudnnConvolutionBwdFilterAlgoPerf_t bwd_filter_algo[kMaxAlgos]; - CHECK_EQ(cudnnFindConvolutionBackwardFilterAlgorithm(s->dnn_handle_, - out_desc_, - in_desc_, - backward_conv_desc_, - filter_desc_, - kMaxAlgos, - &nalgo, - bwd_filter_algo), CUDNN_STATUS_SUCCESS); + CUDNN_CALL(cudnnFindConvolutionBackwardFilterAlgorithm(s->dnn_handle_, + out_desc_, + in_desc_, + backward_conv_desc_, + filter_desc_, + kMaxAlgos, + &nalgo, + bwd_filter_algo)); i = 0; while (i < nalgo && (bwd_filter_algo[i].status != CUDNN_STATUS_SUCCESS @@ -632,14 +624,14 @@ class CuDNNDeconvolutionOp : public Operator { } cudnnConvolutionBwdDataAlgoPerf_t bwd_data_algo[kMaxAlgos]; - CHECK_EQ(cudnnFindConvolutionBackwardDataAlgorithm(s->dnn_handle_, - filter_desc_, - in_desc_, - forward_conv_desc_, // this backward algorithm used for inference - out_desc_, - kMaxAlgos, - &nalgo, - bwd_data_algo), CUDNN_STATUS_SUCCESS); + CUDNN_CALL(cudnnFindConvolutionBackwardDataAlgorithm(s->dnn_handle_, + filter_desc_, + in_desc_, + forward_conv_desc_, // this backward algorithm used for inference + out_desc_, + kMaxAlgos, + &nalgo, + bwd_data_algo)); i = 0; while (i < nalgo && (bwd_data_algo[i].status != CUDNN_STATUS_SUCCESS @@ -663,28 +655,28 @@ class CuDNNDeconvolutionOp : public Operator { if (init_temp_size_) return; 
mshadow::Stream *s = ctx.get_stream(); size_t back_size = 0, back_size_w = 0; - CHECK_EQ(cudnnGetConvolutionBackwardDataWorkspaceSize(s->dnn_handle_, + CUDNN_CALL(cudnnGetConvolutionBackwardDataWorkspaceSize(s->dnn_handle_, filter_desc_, in_desc_, forward_conv_desc_, out_desc_, back_algo_, - &back_size), CUDNN_STATUS_SUCCESS); - CHECK_EQ(cudnnGetConvolutionBackwardFilterWorkspaceSize(s->dnn_handle_, + &back_size)); + CUDNN_CALL(cudnnGetConvolutionBackwardFilterWorkspaceSize(s->dnn_handle_, out_desc_, in_desc_, backward_conv_desc_, filter_desc_, back_algo_w_, - &back_size_w), CUDNN_STATUS_SUCCESS); + &back_size_w)); backward_workspace_byte_ = std::max(back_size, back_size_w); - CHECK_EQ(cudnnGetConvolutionForwardWorkspaceSize(s->dnn_handle_, + CUDNN_CALL(cudnnGetConvolutionForwardWorkspaceSize(s->dnn_handle_, out_desc_, filter_desc_, backward_conv_desc_, in_desc_, algo_, - &forward_workspace_byte_), CUDNN_STATUS_SUCCESS); + &forward_workspace_byte_)); forward_workspace_ = forward_workspace_byte_ / sizeof(DType) + 1; backward_workspace_ = backward_workspace_byte_ / sizeof(DType) + 1; diff --git a/src/operator/cudnn_lrn-inl.h b/src/operator/cudnn_lrn-inl.h index 52eb1dac04e5..d65a678bc07d 100755 --- a/src/operator/cudnn_lrn-inl.h +++ b/src/operator/cudnn_lrn-inl.h @@ -23,8 +23,8 @@ class CuDNNLocalResponseNormOp : public Operator { ~CuDNNLocalResponseNormOp() { if (init_cudnn_) { - CHECK_EQ(cudnnDestroyLRNDescriptor(lrn_desc_), CUDNN_STATUS_SUCCESS); - CHECK_EQ(cudnnDestroyTensorDescriptor(shape_desc_), CUDNN_STATUS_SUCCESS); + CUDNN_CALL(cudnnDestroyLRNDescriptor(lrn_desc_)); + CUDNN_CALL(cudnnDestroyTensorDescriptor(shape_desc_)); } } @@ -46,15 +46,15 @@ class CuDNNLocalResponseNormOp : public Operator { this->Init(s, in_data, out_data); } CHECK_EQ(s->dnn_handle_ownership_, mshadow::Stream::OwnHandle); - CHECK_EQ(cudnnLRNCrossChannelForward(s->dnn_handle_, - lrn_desc_, - CUDNN_LRN_CROSS_CHANNEL_DIM1, - &alpha, - shape_desc_, - data.dptr_, - &beta, - shape_desc_, - out.dptr_), CUDNN_STATUS_SUCCESS); + CUDNN_CALL(cudnnLRNCrossChannelForward(s->dnn_handle_, + lrn_desc_, + CUDNN_LRN_CROSS_CHANNEL_DIM1, + &alpha, + shape_desc_, + data.dptr_, + &beta, + shape_desc_, + out.dptr_)); } virtual void Backward(const OpContext &ctx, @@ -79,19 +79,19 @@ class CuDNNLocalResponseNormOp : public Operator { Tensor output_data = out_data[lrn_enum::kOut].get(s); Tensor input_grad = in_grad[lrn_enum::kData].get(s); CHECK_EQ(s->dnn_handle_ownership_, mshadow::Stream::OwnHandle); - CHECK_EQ(cudnnLRNCrossChannelBackward(s->dnn_handle_, - lrn_desc_, - CUDNN_LRN_CROSS_CHANNEL_DIM1, - &alpha, - shape_desc_, - output_data.dptr_, - shape_desc_, - grad.dptr_, - shape_desc_, - data.dptr_, - &beta, - shape_desc_, - input_grad.dptr_), CUDNN_STATUS_SUCCESS); + CUDNN_CALL(cudnnLRNCrossChannelBackward(s->dnn_handle_, + lrn_desc_, + CUDNN_LRN_CROSS_CHANNEL_DIM1, + &alpha, + shape_desc_, + output_data.dptr_, + shape_desc_, + grad.dptr_, + shape_desc_, + data.dptr_, + &beta, + shape_desc_, + input_grad.dptr_)); } private: @@ -110,20 +110,20 @@ class CuDNNLocalResponseNormOp : public Operator { double beta = param_.beta; double lrn_k = param_.knorm; CHECK_EQ(data.shape_, out.shape_); - CHECK_EQ(cudnnCreateLRNDescriptor(&lrn_desc_), CUDNN_STATUS_SUCCESS); - CHECK_EQ(cudnnSetLRNDescriptor(lrn_desc_, - lrn_n, - alpha, - beta, - lrn_k), CUDNN_STATUS_SUCCESS); - CHECK_EQ(cudnnCreateTensorDescriptor(&shape_desc_), CUDNN_STATUS_SUCCESS); - CHECK_EQ(cudnnSetTensor4dDescriptor(shape_desc_, - CUDNN_TENSOR_NCHW, - dtype_, - data.shape_[0], 
- data.shape_[1], - data.shape_[2], - data.shape_[3]), CUDNN_STATUS_SUCCESS); + CUDNN_CALL(cudnnCreateLRNDescriptor(&lrn_desc_)); + CUDNN_CALL(cudnnSetLRNDescriptor(lrn_desc_, + lrn_n, + alpha, + beta, + lrn_k)); + CUDNN_CALL(cudnnCreateTensorDescriptor(&shape_desc_)); + CUDNN_CALL(cudnnSetTensor4dDescriptor(shape_desc_, + CUDNN_TENSOR_NCHW, + dtype_, + data.shape_[0], + data.shape_[1], + data.shape_[2], + data.shape_[3])); } } bool init_cudnn_; diff --git a/src/operator/cudnn_pooling-inl.h b/src/operator/cudnn_pooling-inl.h index 9087efdec64f..3c9344ec5aeb 100644 --- a/src/operator/cudnn_pooling-inl.h +++ b/src/operator/cudnn_pooling-inl.h @@ -36,9 +36,9 @@ class CuDNNPoolingOp : public Operator { ~CuDNNPoolingOp() { if (init_cudnn_) { - CHECK_EQ(cudnnDestroyTensorDescriptor(in_desc_), CUDNN_STATUS_SUCCESS); - CHECK_EQ(cudnnDestroyTensorDescriptor(out_desc_), CUDNN_STATUS_SUCCESS); - CHECK_EQ(cudnnDestroyPoolingDescriptor(pooling_desc_), CUDNN_STATUS_SUCCESS); + CUDNN_CALL(cudnnDestroyTensorDescriptor(in_desc_)); + CUDNN_CALL(cudnnDestroyTensorDescriptor(out_desc_)); + CUDNN_CALL(cudnnDestroyPoolingDescriptor(pooling_desc_)); } } @@ -64,14 +64,14 @@ class CuDNNPoolingOp : public Operator { } CHECK_EQ(data.CheckContiguous(), true); CHECK_EQ(out.CheckContiguous(), true); - CHECK_EQ(cudnnPoolingForward(s->dnn_handle_, - pooling_desc_, - &alpha, - in_desc_, - data.dptr_, - &beta, - out_desc_, - out.dptr_), CUDNN_STATUS_SUCCESS); + CUDNN_CALL(cudnnPoolingForward(s->dnn_handle_, + pooling_desc_, + &alpha, + in_desc_, + data.dptr_, + &beta, + out_desc_, + out.dptr_)); } else if (param_.kernel.ndim() == 3) { // 3d pool Tensor data = in_data[pool_enum::kData].get(s); @@ -81,14 +81,14 @@ class CuDNNPoolingOp : public Operator { } CHECK_EQ(data.CheckContiguous(), true); CHECK_EQ(out.CheckContiguous(), true); - CHECK_EQ(cudnnPoolingForward(s->dnn_handle_, - pooling_desc_, - &alpha, - in_desc_, - data.dptr_, - &beta, - out_desc_, - out.dptr_), CUDNN_STATUS_SUCCESS); + CUDNN_CALL(cudnnPoolingForward(s->dnn_handle_, + pooling_desc_, + &alpha, + in_desc_, + data.dptr_, + &beta, + out_desc_, + out.dptr_)); } else { LOG(FATAL) << "Only support 2D or 3D pooling"; } @@ -119,36 +119,36 @@ class CuDNNPoolingOp : public Operator { Tensor m_in_data = in_data[pool_enum::kData].get(s); Tensor m_out_data = out_data[pool_enum::kOut].get(s); Tensor m_in_grad = in_grad[pool_enum::kData].get(s); - CHECK_EQ(cudnnPoolingBackward(s->dnn_handle_, - pooling_desc_, - &alpha, - out_desc_, - m_out_data.dptr_, - out_desc_, - m_out_grad.dptr_, - in_desc_, - m_in_data.dptr_, - &beta, - in_desc_, - m_in_grad.dptr_), CUDNN_STATUS_SUCCESS); + CUDNN_CALL(cudnnPoolingBackward(s->dnn_handle_, + pooling_desc_, + &alpha, + out_desc_, + m_out_data.dptr_, + out_desc_, + m_out_grad.dptr_, + in_desc_, + m_in_data.dptr_, + &beta, + in_desc_, + m_in_grad.dptr_)); } else if (param_.kernel.ndim() == 3) { // 3d pool Tensor m_out_grad = out_grad[pool_enum::kOut].get(s); Tensor m_in_data = in_data[pool_enum::kData].get(s); Tensor m_out_data = out_data[pool_enum::kOut].get(s); Tensor m_in_grad = in_grad[pool_enum::kData].get(s); - CHECK_EQ(cudnnPoolingBackward(s->dnn_handle_, - pooling_desc_, - &alpha, - out_desc_, - m_out_data.dptr_, - out_desc_, - m_out_grad.dptr_, - in_desc_, - m_in_data.dptr_, - &beta, - in_desc_, - m_in_grad.dptr_), CUDNN_STATUS_SUCCESS); + CUDNN_CALL(cudnnPoolingBackward(s->dnn_handle_, + pooling_desc_, + &alpha, + out_desc_, + m_out_data.dptr_, + out_desc_, + m_out_grad.dptr_, + in_desc_, + m_in_data.dptr_, + &beta, + 
in_desc_, + m_in_grad.dptr_)); } else { LOG(FATAL) << "Only support 2D or 3D pooling"; } @@ -171,25 +171,25 @@ class CuDNNPoolingOp : public Operator { Tensor data = in_data[pool_enum::kData].get(s); Tensor out = out_data[pool_enum::kOut].get(s); mshadow::Shape<4> dshape = data.shape_; - CHECK_EQ(cudnnCreatePoolingDescriptor(&pooling_desc_), CUDNN_STATUS_SUCCESS); - CHECK_EQ(cudnnCreateTensorDescriptor(&in_desc_), CUDNN_STATUS_SUCCESS); - CHECK_EQ(cudnnCreateTensorDescriptor(&out_desc_), CUDNN_STATUS_SUCCESS); - CHECK_EQ(cudnnSetTensor4dDescriptor(in_desc_, - CUDNN_TENSOR_NCHW, - dtype_, - data.shape_[0], - data.shape_[1], - data.shape_[2], - data.shape_[3]), CUDNN_STATUS_SUCCESS); - CHECK_EQ(cudnnSetTensor4dDescriptor(out_desc_, - CUDNN_TENSOR_NCHW, - dtype_, - out.shape_[0], - out.shape_[1], - out.shape_[2], - out.shape_[3]), CUDNN_STATUS_SUCCESS); + CUDNN_CALL(cudnnCreatePoolingDescriptor(&pooling_desc_)); + CUDNN_CALL(cudnnCreateTensorDescriptor(&in_desc_)); + CUDNN_CALL(cudnnCreateTensorDescriptor(&out_desc_)); + CUDNN_CALL(cudnnSetTensor4dDescriptor(in_desc_, + CUDNN_TENSOR_NCHW, + dtype_, + data.shape_[0], + data.shape_[1], + data.shape_[2], + data.shape_[3])); + CUDNN_CALL(cudnnSetTensor4dDescriptor(out_desc_, + CUDNN_TENSOR_NCHW, + dtype_, + out.shape_[0], + out.shape_[1], + out.shape_[2], + out.shape_[3])); #if CUDNN_MAJOR >= 5 - CHECK_EQ(cudnnSetPooling2dDescriptor(pooling_desc_, + CUDNN_CALL(cudnnSetPooling2dDescriptor(pooling_desc_, mode_, nan_prop_, param_.global_pool ? dshape[2] : param_.kernel[0], @@ -197,25 +197,23 @@ class CuDNNPoolingOp : public Operator { param_.pad[0], param_.pad[1], param_.global_pool ? 1 : param_.stride[0], - param_.global_pool ? 1 :param_.stride[1]), - CUDNN_STATUS_SUCCESS); + param_.global_pool ? 1 :param_.stride[1])); #else - CHECK_EQ(cudnnSetPooling2dDescriptor(pooling_desc_, + CUDNN_CALL(cudnnSetPooling2dDescriptor(pooling_desc_, mode_, param_.global_pool ? dshape[2] : param_.kernel[0], param_.global_pool ? dshape[3] : param_.kernel[1], param_.pad[0], param_.pad[1], param_.global_pool ? 1 : param_.stride[0], - param_.global_pool ? 1 : param_.stride[1]), - CUDNN_STATUS_SUCCESS); + param_.global_pool ? 1 : param_.stride[1])); #endif } else { Tensor data = in_data[pool_enum::kData].get(s); Tensor out = out_data[pool_enum::kOut].get(s); - CHECK_EQ(cudnnCreatePoolingDescriptor(&pooling_desc_), CUDNN_STATUS_SUCCESS); - CHECK_EQ(cudnnCreateTensorDescriptor(&in_desc_), CUDNN_STATUS_SUCCESS); - CHECK_EQ(cudnnCreateTensorDescriptor(&out_desc_), CUDNN_STATUS_SUCCESS); + CUDNN_CALL(cudnnCreatePoolingDescriptor(&pooling_desc_)); + CUDNN_CALL(cudnnCreateTensorDescriptor(&in_desc_)); + CUDNN_CALL(cudnnCreateTensorDescriptor(&out_desc_)); std::vector ishape = {static_cast(data.shape_[0]), static_cast(data.shape_[1]), static_cast(data.shape_[2]), @@ -255,24 +253,24 @@ class CuDNNPoolingOp : public Operator { param_.global_pool ? 1 : static_cast(param_.stride[1]), param_.global_pool ? 
1 : static_cast(param_.stride[2])}; - CHECK_EQ(cudnnSetTensorNdDescriptor(in_desc_, - dtype_, - static_cast(ishape.size()), - &ishape[0], - &istride[0]), CUDNN_STATUS_SUCCESS); - CHECK_EQ(cudnnSetTensorNdDescriptor(out_desc_, - dtype_, - static_cast(oshape.size()), - &oshape[0], - &ostride[0]), CUDNN_STATUS_SUCCESS); + CUDNN_CALL(cudnnSetTensorNdDescriptor(in_desc_, + dtype_, + static_cast(ishape.size()), + &ishape[0], + &istride[0])); + CUDNN_CALL(cudnnSetTensorNdDescriptor(out_desc_, + dtype_, + static_cast(oshape.size()), + &oshape[0], + &ostride[0])); #if CUDNN_MAJOR >= 5 - CHECK_EQ(cudnnSetPoolingNdDescriptor(pooling_desc_, - mode_, - nan_prop_, - static_cast(kernel_vec.size()), - &(kernel_vec[0]), - &(pad_vec[0]), - &(stride_vec[0])), CUDNN_STATUS_SUCCESS); + CUDNN_CALL(cudnnSetPoolingNdDescriptor(pooling_desc_, + mode_, + nan_prop_, + static_cast(kernel_vec.size()), + &(kernel_vec[0]), + &(pad_vec[0]), + &(stride_vec[0]))); #else LOG(FATAL) << "3D pooling only support CUDNN v5 and abouve"; #endif diff --git a/src/operator/cudnn_rnn-inl.h b/src/operator/cudnn_rnn-inl.h index e6bd1c8c3931..a4ce10edd886 100644 --- a/src/operator/cudnn_rnn-inl.h +++ b/src/operator/cudnn_rnn-inl.h @@ -56,23 +56,23 @@ class CuDNNRNNOp : public Operator { ~CuDNNRNNOp() { if (init_cudnn_) { for (size_t i = 0; i < x_desc_vec_.size(); ++i) { - CHECK_EQ(cudnnDestroyTensorDescriptor(x_desc_vec_[i]), CUDNN_STATUS_SUCCESS); - CHECK_EQ(cudnnDestroyTensorDescriptor(y_desc_vec_[i]), CUDNN_STATUS_SUCCESS); - CHECK_EQ(cudnnDestroyTensorDescriptor(dx_desc_vec_[i]), CUDNN_STATUS_SUCCESS); - CHECK_EQ(cudnnDestroyTensorDescriptor(dy_desc_vec_[i]), CUDNN_STATUS_SUCCESS); + CUDNN_CALL(cudnnDestroyTensorDescriptor(x_desc_vec_[i])); + CUDNN_CALL(cudnnDestroyTensorDescriptor(y_desc_vec_[i])); + CUDNN_CALL(cudnnDestroyTensorDescriptor(dx_desc_vec_[i])); + CUDNN_CALL(cudnnDestroyTensorDescriptor(dy_desc_vec_[i])); } - CHECK_EQ(cudnnDestroyTensorDescriptor(hx_desc_), CUDNN_STATUS_SUCCESS); - CHECK_EQ(cudnnDestroyTensorDescriptor(cx_desc_), CUDNN_STATUS_SUCCESS); - CHECK_EQ(cudnnDestroyTensorDescriptor(hy_desc_), CUDNN_STATUS_SUCCESS); - CHECK_EQ(cudnnDestroyTensorDescriptor(cy_desc_), CUDNN_STATUS_SUCCESS); - CHECK_EQ(cudnnDestroyTensorDescriptor(dhx_desc_), CUDNN_STATUS_SUCCESS); - CHECK_EQ(cudnnDestroyTensorDescriptor(dcx_desc_), CUDNN_STATUS_SUCCESS); - CHECK_EQ(cudnnDestroyTensorDescriptor(dhy_desc_), CUDNN_STATUS_SUCCESS); - CHECK_EQ(cudnnDestroyTensorDescriptor(dcy_desc_), CUDNN_STATUS_SUCCESS); - - CHECK_EQ(cudnnDestroyFilterDescriptor(w_desc_), CUDNN_STATUS_SUCCESS); - CHECK_EQ(cudnnDestroyRNNDescriptor(rnn_desc_), CUDNN_STATUS_SUCCESS); - CHECK_EQ(cudnnDestroyDropoutDescriptor(dropout_desc_), CUDNN_STATUS_SUCCESS); + CUDNN_CALL(cudnnDestroyTensorDescriptor(hx_desc_)); + CUDNN_CALL(cudnnDestroyTensorDescriptor(cx_desc_)); + CUDNN_CALL(cudnnDestroyTensorDescriptor(hy_desc_)); + CUDNN_CALL(cudnnDestroyTensorDescriptor(cy_desc_)); + CUDNN_CALL(cudnnDestroyTensorDescriptor(dhx_desc_)); + CUDNN_CALL(cudnnDestroyTensorDescriptor(dcx_desc_)); + CUDNN_CALL(cudnnDestroyTensorDescriptor(dhy_desc_)); + CUDNN_CALL(cudnnDestroyTensorDescriptor(dcy_desc_)); + + CUDNN_CALL(cudnnDestroyFilterDescriptor(w_desc_)); + CUDNN_CALL(cudnnDestroyRNNDescriptor(rnn_desc_)); + CUDNN_CALL(cudnnDestroyDropoutDescriptor(dropout_desc_)); Storage::Get()->Free(dropout_states_); Storage::Get()->Free(reserve_space_); } @@ -124,48 +124,48 @@ class CuDNNRNNOp : public Operator { ctx.requested[rnn_enum::kTempSpace].get_space_typed( mshadow::Shape1(temp_size), 
s); if (ctx.is_train) { - CHECK_EQ(cudnnRNNForwardTraining(s->dnn_handle_, - rnn_desc_, - param_.seq_length_, - x_desc_vec_.data(), - x.dptr_, - hx_desc_, - hx.dptr_, - cx_desc_, - cx_ptr, - w_desc_, - w.dptr_, - y_desc_vec_.data(), - y.dptr_, - hy_desc_, - hy_ptr, - cy_desc_, - cy_ptr, - temp_space.dptr_, - workspace_byte_, - reserve_space_.dptr, - reserve_space_byte_), CUDNN_STATUS_SUCCESS); + CUDNN_CALL(cudnnRNNForwardTraining(s->dnn_handle_, + rnn_desc_, + param_.seq_length_, + x_desc_vec_.data(), + x.dptr_, + hx_desc_, + hx.dptr_, + cx_desc_, + cx_ptr, + w_desc_, + w.dptr_, + y_desc_vec_.data(), + y.dptr_, + hy_desc_, + hy_ptr, + cy_desc_, + cy_ptr, + temp_space.dptr_, + workspace_byte_, + reserve_space_.dptr, + reserve_space_byte_)); } else { // inference mode - CHECK_EQ(cudnnRNNForwardInference(s->dnn_handle_, - rnn_desc_, - param_.seq_length_, - x_desc_vec_.data(), - x.dptr_, - hx_desc_, - hx.dptr_, - cx_desc_, - cx_ptr, - w_desc_, - w.dptr_, - y_desc_vec_.data(), - y.dptr_, - hy_desc_, - hy_ptr, - cy_desc_, - cy_ptr, - temp_space.dptr_, - workspace_byte_), CUDNN_STATUS_SUCCESS); + CUDNN_CALL(cudnnRNNForwardInference(s->dnn_handle_, + rnn_desc_, + param_.seq_length_, + x_desc_vec_.data(), + x.dptr_, + hx_desc_, + hx.dptr_, + cx_desc_, + cx_ptr, + w_desc_, + w.dptr_, + y_desc_vec_.data(), + y.dptr_, + hy_desc_, + hy_ptr, + cy_desc_, + cy_ptr, + temp_space.dptr_, + workspace_byte_)); } } @@ -237,48 +237,48 @@ class CuDNNRNNOp : public Operator { Tensor temp_space = ctx.requested[rnn_enum::kTempSpace].get_space_typed( mshadow::Shape1(temp_size), s); - CHECK_EQ(cudnnRNNBackwardData(s->dnn_handle_, - rnn_desc_, - param_.seq_length_, - y_desc_vec_.data(), - y.dptr_, - dy_desc_vec_.data(), - dy.dptr_, - dhy_desc_, - dhy_ptr, - dcy_desc_, - dcy_ptr, - w_desc_, - w.dptr_, - hx_desc_, - hx.dptr_, - cx_desc_, - cx_ptr, - dx_desc_vec_.data(), - dx.dptr_, - dhx_desc_, - dhx.dptr_, - dcx_desc_, - dcx_ptr, - temp_space.dptr_, - workspace_byte_, - reserve_space_.dptr, - reserve_space_byte_), CUDNN_STATUS_SUCCESS); - CHECK_EQ(cudnnRNNBackwardWeights(s->dnn_handle_, + CUDNN_CALL(cudnnRNNBackwardData(s->dnn_handle_, rnn_desc_, param_.seq_length_, - x_desc_vec_.data(), - x.dptr_, - hx_desc_, - hx.dptr_, y_desc_vec_.data(), y.dptr_, + dy_desc_vec_.data(), + dy.dptr_, + dhy_desc_, + dhy_ptr, + dcy_desc_, + dcy_ptr, + w_desc_, + w.dptr_, + hx_desc_, + hx.dptr_, + cx_desc_, + cx_ptr, + dx_desc_vec_.data(), + dx.dptr_, + dhx_desc_, + dhx.dptr_, + dcx_desc_, + dcx_ptr, temp_space.dptr_, workspace_byte_, - dw_desc_, - dw.dptr_, reserve_space_.dptr, - reserve_space_byte_), CUDNN_STATUS_SUCCESS); + reserve_space_byte_)); + CUDNN_CALL(cudnnRNNBackwardWeights(s->dnn_handle_, + rnn_desc_, + param_.seq_length_, + x_desc_vec_.data(), + x.dptr_, + hx_desc_, + hx.dptr_, + y_desc_vec_.data(), + y.dptr_, + temp_space.dptr_, + workspace_byte_, + dw_desc_, + dw.dptr_, + reserve_space_.dptr, + reserve_space_byte_)); } private: @@ -313,10 +313,10 @@ class CuDNNRNNOp : public Operator { int dimA[3]; int strideA[3]; for (int i = 0; i < param_.seq_length_; i++) { - CHECK_EQ(cudnnCreateTensorDescriptor(&x_vec[i]), CUDNN_STATUS_SUCCESS); - CHECK_EQ(cudnnCreateTensorDescriptor(&y_vec[i]), CUDNN_STATUS_SUCCESS); - CHECK_EQ(cudnnCreateTensorDescriptor(&dx_vec[i]), CUDNN_STATUS_SUCCESS); - CHECK_EQ(cudnnCreateTensorDescriptor(&dy_vec[i]), CUDNN_STATUS_SUCCESS); + CUDNN_CALL(cudnnCreateTensorDescriptor(&x_vec[i])); + CUDNN_CALL(cudnnCreateTensorDescriptor(&y_vec[i])); + CUDNN_CALL(cudnnCreateTensorDescriptor(&dx_vec[i])); + 
CUDNN_CALL(cudnnCreateTensorDescriptor(&dy_vec[i])); dimA[0] = param_.batch_size_; dimA[1] = param_.input_size_; @@ -327,16 +327,16 @@ class CuDNNRNNOp : public Operator { strideA[1] = dimA[2]; strideA[2] = 1; - CHECK_EQ(cudnnSetTensorNdDescriptor(x_vec[i], - dtype_, - 3, - dimA, - strideA), CUDNN_STATUS_SUCCESS); - CHECK_EQ(cudnnSetTensorNdDescriptor(dx_vec[i], - dtype_, - 3, - dimA, - strideA), CUDNN_STATUS_SUCCESS); + CUDNN_CALL(cudnnSetTensorNdDescriptor(x_vec[i], + dtype_, + 3, + dimA, + strideA)); + CUDNN_CALL(cudnnSetTensorNdDescriptor(dx_vec[i], + dtype_, + 3, + dimA, + strideA)); dimA[0] = param_.batch_size_; dimA[1] = param_.bidirectional ? param_.state_size * 2 : param_.state_size; dimA[2] = 1; @@ -344,16 +344,16 @@ class CuDNNRNNOp : public Operator { strideA[1] = dimA[2]; strideA[2] = 1; - CHECK_EQ(cudnnSetTensorNdDescriptor(y_vec[i], - dtype_, - 3, - dimA, - strideA), CUDNN_STATUS_SUCCESS); - CHECK_EQ(cudnnSetTensorNdDescriptor(dy_vec[i], - dtype_, - 3, - dimA, - strideA), CUDNN_STATUS_SUCCESS); + CUDNN_CALL(cudnnSetTensorNdDescriptor(y_vec[i], + dtype_, + 3, + dimA, + strideA)); + CUDNN_CALL(cudnnSetTensorNdDescriptor(dy_vec[i], + dtype_, + 3, + dimA, + strideA)); } x_desc_vec_ = x_vec; y_desc_vec_ = y_vec; @@ -368,117 +368,117 @@ class CuDNNRNNOp : public Operator { strideA[1] = dimA[2]; strideA[2] = 1; - CHECK_EQ(cudnnCreateTensorDescriptor(&hx_desc_), CUDNN_STATUS_SUCCESS); - CHECK_EQ(cudnnCreateTensorDescriptor(&cx_desc_), CUDNN_STATUS_SUCCESS); - CHECK_EQ(cudnnCreateTensorDescriptor(&hy_desc_), CUDNN_STATUS_SUCCESS); - CHECK_EQ(cudnnCreateTensorDescriptor(&cy_desc_), CUDNN_STATUS_SUCCESS); - CHECK_EQ(cudnnCreateTensorDescriptor(&dhx_desc_), CUDNN_STATUS_SUCCESS); - CHECK_EQ(cudnnCreateTensorDescriptor(&dcx_desc_), CUDNN_STATUS_SUCCESS); - CHECK_EQ(cudnnCreateTensorDescriptor(&dhy_desc_), CUDNN_STATUS_SUCCESS); - CHECK_EQ(cudnnCreateTensorDescriptor(&dcy_desc_), CUDNN_STATUS_SUCCESS); - - CHECK_EQ(cudnnSetTensorNdDescriptor(hx_desc_, - dtype_, - 3, - dimA, - strideA), CUDNN_STATUS_SUCCESS); - CHECK_EQ(cudnnSetTensorNdDescriptor(cx_desc_, - dtype_, - 3, - dimA, - strideA), CUDNN_STATUS_SUCCESS); - CHECK_EQ(cudnnSetTensorNdDescriptor(hy_desc_, - dtype_, - 3, - dimA, - strideA), CUDNN_STATUS_SUCCESS); - CHECK_EQ(cudnnSetTensorNdDescriptor(cy_desc_, - dtype_, - 3, - dimA, - strideA), CUDNN_STATUS_SUCCESS); - CHECK_EQ(cudnnSetTensorNdDescriptor(dhx_desc_, - dtype_, - 3, - dimA, - strideA), CUDNN_STATUS_SUCCESS); - CHECK_EQ(cudnnSetTensorNdDescriptor(dcx_desc_, - dtype_, - 3, - dimA, - strideA), CUDNN_STATUS_SUCCESS); - CHECK_EQ(cudnnSetTensorNdDescriptor(dhy_desc_, - dtype_, - 3, - dimA, - strideA), CUDNN_STATUS_SUCCESS); - CHECK_EQ(cudnnSetTensorNdDescriptor(dcy_desc_, - dtype_, - 3, - dimA, - strideA), CUDNN_STATUS_SUCCESS); + CUDNN_CALL(cudnnCreateTensorDescriptor(&hx_desc_)); + CUDNN_CALL(cudnnCreateTensorDescriptor(&cx_desc_)); + CUDNN_CALL(cudnnCreateTensorDescriptor(&hy_desc_)); + CUDNN_CALL(cudnnCreateTensorDescriptor(&cy_desc_)); + CUDNN_CALL(cudnnCreateTensorDescriptor(&dhx_desc_)); + CUDNN_CALL(cudnnCreateTensorDescriptor(&dcx_desc_)); + CUDNN_CALL(cudnnCreateTensorDescriptor(&dhy_desc_)); + CUDNN_CALL(cudnnCreateTensorDescriptor(&dcy_desc_)); + + CUDNN_CALL(cudnnSetTensorNdDescriptor(hx_desc_, + dtype_, + 3, + dimA, + strideA)); + CUDNN_CALL(cudnnSetTensorNdDescriptor(cx_desc_, + dtype_, + 3, + dimA, + strideA)); + CUDNN_CALL(cudnnSetTensorNdDescriptor(hy_desc_, + dtype_, + 3, + dimA, + strideA)); + CUDNN_CALL(cudnnSetTensorNdDescriptor(cy_desc_, + dtype_, + 3, + 
dimA, + strideA)); + CUDNN_CALL(cudnnSetTensorNdDescriptor(dhx_desc_, + dtype_, + 3, + dimA, + strideA)); + CUDNN_CALL(cudnnSetTensorNdDescriptor(dcx_desc_, + dtype_, + 3, + dimA, + strideA)); + CUDNN_CALL(cudnnSetTensorNdDescriptor(dhy_desc_, + dtype_, + 3, + dimA, + strideA)); + CUDNN_CALL(cudnnSetTensorNdDescriptor(dcy_desc_, + dtype_, + 3, + dimA, + strideA)); // Create Dropout descriptors - CHECK_EQ(cudnnCreateDropoutDescriptor(&dropout_desc_), CUDNN_STATUS_SUCCESS); - CHECK_EQ(cudnnDropoutGetStatesSize(s->dnn_handle_, - &dropout_byte_), CUDNN_STATUS_SUCCESS); + CUDNN_CALL(cudnnCreateDropoutDescriptor(&dropout_desc_)); + CUDNN_CALL(cudnnDropoutGetStatesSize(s->dnn_handle_, + &dropout_byte_)); dropout_size_ = dropout_byte_ / sizeof(DType); dropout_states_ = Storage::Get()->Alloc(dropout_byte_, Context::GPU()); - CHECK_EQ(cudnnSetDropoutDescriptor(dropout_desc_, - s->dnn_handle_, - param_.p, // keep probability - dropout_states_.dptr, - dropout_byte_, - seed_), CUDNN_STATUS_SUCCESS); + CUDNN_CALL(cudnnSetDropoutDescriptor(dropout_desc_, + s->dnn_handle_, + param_.p, // keep probability + dropout_states_.dptr, + dropout_byte_, + seed_)); // RNN descriptors - CHECK_EQ(cudnnCreateRNNDescriptor(&rnn_desc_), CUDNN_STATUS_SUCCESS); - CHECK_EQ(cudnnSetRNNDescriptor(rnn_desc_, - param_.state_size, - param_.num_layers, - dropout_desc_, - input_mode_, - direction_, - mode_, - dtype_), CUDNN_STATUS_SUCCESS); + CUDNN_CALL(cudnnCreateRNNDescriptor(&rnn_desc_)); + CUDNN_CALL(cudnnSetRNNDescriptor(rnn_desc_, + param_.state_size, + param_.num_layers, + dropout_desc_, + input_mode_, + direction_, + mode_, + dtype_)); // Get temp space sizes - CHECK_EQ(cudnnGetRNNWorkspaceSize(s->dnn_handle_, - rnn_desc_, - param_.seq_length_, - x_desc_vec_.data(), - &workspace_byte_), CUDNN_STATUS_SUCCESS); - CHECK_EQ(cudnnGetRNNTrainingReserveSize(s->dnn_handle_, - rnn_desc_, - param_.seq_length_, - x_desc_vec_.data(), - &reserve_space_byte_), CUDNN_STATUS_SUCCESS); + CUDNN_CALL(cudnnGetRNNWorkspaceSize(s->dnn_handle_, + rnn_desc_, + param_.seq_length_, + x_desc_vec_.data(), + &workspace_byte_)); + CUDNN_CALL(cudnnGetRNNTrainingReserveSize(s->dnn_handle_, + rnn_desc_, + param_.seq_length_, + x_desc_vec_.data(), + &reserve_space_byte_)); workspace_size_ = workspace_byte_ / sizeof(DType); // Allocate the reserve space reserve_space_ = Storage::Get()->Alloc(reserve_space_byte_, Context::GPU()); // Check that number of params are correct size_t cudnn_param_size; - CHECK_EQ(cudnnGetRNNParamsSize(s->dnn_handle_, - rnn_desc_, - x_desc_vec_[0], - &cudnn_param_size, - dtype_), CUDNN_STATUS_SUCCESS); + CUDNN_CALL(cudnnGetRNNParamsSize(s->dnn_handle_, + rnn_desc_, + x_desc_vec_[0], + &cudnn_param_size, + dtype_)); CHECK_EQ(w.shape_[0] * sizeof(DType), cudnn_param_size); // Set param descriptors - CHECK_EQ(cudnnCreateFilterDescriptor(&w_desc_), CUDNN_STATUS_SUCCESS); - CHECK_EQ(cudnnCreateFilterDescriptor(&dw_desc_), CUDNN_STATUS_SUCCESS); + CUDNN_CALL(cudnnCreateFilterDescriptor(&w_desc_)); + CUDNN_CALL(cudnnCreateFilterDescriptor(&dw_desc_)); int dim_w[3] = {1, 1, 1}; dim_w[0] = w.shape_[0]; - CHECK_EQ(cudnnSetFilterNdDescriptor(w_desc_, - dtype_, - format_, - 3, - dim_w), CUDNN_STATUS_SUCCESS); - CHECK_EQ(cudnnSetFilterNdDescriptor(dw_desc_, - dtype_, - format_, - 3, - dim_w), CUDNN_STATUS_SUCCESS); + CUDNN_CALL(cudnnSetFilterNdDescriptor(w_desc_, + dtype_, + format_, + 3, + dim_w)); + CUDNN_CALL(cudnnSetFilterNdDescriptor(dw_desc_, + dtype_, + format_, + 3, + dim_w)); // Query weight layout // cudnnFilterDescriptor_t m_desc; 
diff --git a/src/operator/cudnn_softmax_activation-inl.h b/src/operator/cudnn_softmax_activation-inl.h index d44d08394126..86c27317f923 100644 --- a/src/operator/cudnn_softmax_activation-inl.h +++ b/src/operator/cudnn_softmax_activation-inl.h @@ -23,7 +23,7 @@ class CuDNNSoftmaxActivationOp : public Operator { ~CuDNNSoftmaxActivationOp() { if (init_cudnn_) { - CHECK_EQ(cudnnDestroyTensorDescriptor(shape_desc_), CUDNN_STATUS_SUCCESS); + CUDNN_CALL(cudnnDestroyTensorDescriptor(shape_desc_)); } } @@ -71,24 +71,24 @@ class CuDNNSoftmaxActivationOp : public Operator { CHECK_EQ(s->dnn_handle_ownership_, mshadow::Stream::OwnHandle); if (!init_cudnn_) { init_cudnn_ = true; - CHECK_EQ(cudnnCreateTensorDescriptor(&shape_desc_), CUDNN_STATUS_SUCCESS); - CHECK_EQ(cudnnSetTensor4dDescriptor(shape_desc_, - CUDNN_TENSOR_NCHW, - dtype_, - data.shape_[0], - data.shape_[1], - data.shape_[2], - data.shape_[3]), CUDNN_STATUS_SUCCESS); + CUDNN_CALL(cudnnCreateTensorDescriptor(&shape_desc_)); + CUDNN_CALL(cudnnSetTensor4dDescriptor(shape_desc_, + CUDNN_TENSOR_NCHW, + dtype_, + data.shape_[0], + data.shape_[1], + data.shape_[2], + data.shape_[3])); } - CHECK_EQ(cudnnSoftmaxForward(s->dnn_handle_, - CUDNN_SOFTMAX_ACCURATE, - softmax_mode, - &alpha, - shape_desc_, - data.dptr_, - &beta, - shape_desc_, - out.dptr_), CUDNN_STATUS_SUCCESS); + CUDNN_CALL(cudnnSoftmaxForward(s->dnn_handle_, + CUDNN_SOFTMAX_ACCURATE, + softmax_mode, + &alpha, + shape_desc_, + data.dptr_, + &beta, + shape_desc_, + out.dptr_)); } virtual void Backward(const OpContext &ctx, @@ -141,17 +141,17 @@ class CuDNNSoftmaxActivationOp : public Operator { softmax_mode = CUDNN_SOFTMAX_MODE_CHANNEL; } CHECK_EQ(s->dnn_handle_ownership_, mshadow::Stream::OwnHandle); - CHECK_EQ(cudnnSoftmaxBackward(s->dnn_handle_, - CUDNN_SOFTMAX_ACCURATE, - softmax_mode, - &alpha, - shape_desc_, - output_data.dptr_, - shape_desc_, - grad.dptr_, - &beta, - shape_desc_, - input_grad.dptr_), CUDNN_STATUS_SUCCESS); + CUDNN_CALL(cudnnSoftmaxBackward(s->dnn_handle_, + CUDNN_SOFTMAX_ACCURATE, + softmax_mode, + &alpha, + shape_desc_, + output_data.dptr_, + shape_desc_, + grad.dptr_, + &beta, + shape_desc_, + input_grad.dptr_)); } private: diff --git a/src/operator/cudnn_spatial_transformer-inl.h b/src/operator/cudnn_spatial_transformer-inl.h index 9f48d55a2505..b25e8cebc077 100644 --- a/src/operator/cudnn_spatial_transformer-inl.h +++ b/src/operator/cudnn_spatial_transformer-inl.h @@ -27,9 +27,9 @@ class CuDNNSpatialTransformerOp : public Operator { ~CuDNNSpatialTransformerOp() { if (init_cudnn_) { - CHECK_EQ(cudnnDestroySpatialTransformerDescriptor(st_desc_), CUDNN_STATUS_SUCCESS); - CHECK_EQ(cudnnDestroyTensorDescriptor(in_desc_), CUDNN_STATUS_SUCCESS); - CHECK_EQ(cudnnDestroyTensorDescriptor(out_desc_), CUDNN_STATUS_SUCCESS); + CUDNN_CALL(cudnnDestroySpatialTransformerDescriptor(st_desc_)); + CUDNN_CALL(cudnnDestroyTensorDescriptor(in_desc_)); + CUDNN_CALL(cudnnDestroyTensorDescriptor(out_desc_)); } } @@ -57,20 +57,20 @@ class CuDNNSpatialTransformerOp : public Operator { typename DataType::ScaleType alpha = 1.0f; typename DataType::ScaleType beta = 0.0f; if (param_.transform_type == st::kAffine) { - CHECK_EQ(cudnnSpatialTfGridGeneratorForward(s->dnn_handle_, - st_desc_, - loc.dptr_, - grid.dptr_/*output*/), CUDNN_STATUS_SUCCESS); + CUDNN_CALL(cudnnSpatialTfGridGeneratorForward(s->dnn_handle_, + st_desc_, + loc.dptr_, + grid.dptr_)); } - CHECK_EQ(cudnnSpatialTfSamplerForward(s->dnn_handle_, - st_desc_, - &alpha, - in_desc_, - data.dptr_, - grid.dptr_, - &beta, - out_desc_, - 
out.dptr_/*output*/), CUDNN_STATUS_SUCCESS); + CUDNN_CALL(cudnnSpatialTfSamplerForward(s->dnn_handle_, + st_desc_, + &alpha, + in_desc_, + data.dptr_, + grid.dptr_, + &beta, + out_desc_, + out.dptr_)); } virtual void Backward(const OpContext &ctx, @@ -99,25 +99,25 @@ class CuDNNSpatialTransformerOp : public Operator { typename DataType::ScaleType beta = 0.0f; typename DataType::ScaleType alpha_dgrid = 1.0f; typename DataType::ScaleType beta_dgrid = 0.0f; - CHECK_EQ(cudnnSpatialTfSamplerBackward(s->dnn_handle_, - st_desc_, - &alpha, - in_desc_, - data.dptr_, - &beta, - in_desc_/*reuse in_desc_*/, - ddata.dptr_/*output*/, - &alpha_dgrid, - out_desc_/*reuse out_desc_*/, - grad.dptr_, - grid.dptr_, - &beta_dgrid, - grid.dptr_/*output, reuse grid*/), CUDNN_STATUS_SUCCESS); + CUDNN_CALL(cudnnSpatialTfSamplerBackward(s->dnn_handle_, + st_desc_, + &alpha, + in_desc_, + data.dptr_, + &beta, + in_desc_/*reuse in_desc_*/, + ddata.dptr_/*output*/, + &alpha_dgrid, + out_desc_/*reuse out_desc_*/, + grad.dptr_, + grid.dptr_, + &beta_dgrid, + grid.dptr_)); if (param_.transform_type == st::kAffine) { - CHECK_EQ(cudnnSpatialTfGridGeneratorBackward(s->dnn_handle_, - st_desc_, - grid.dptr_, - dloc.dptr_/*out*/), CUDNN_STATUS_SUCCESS); + CUDNN_CALL(cudnnSpatialTfGridGeneratorBackward(s->dnn_handle_, + st_desc_, + grid.dptr_, + dloc.dptr_/*out*/)); } } @@ -135,31 +135,31 @@ class CuDNNSpatialTransformerOp : public Operator { init_cudnn_ = true; Tensor data = in_data[st::kData].get(s); Tensor out = out_data[st::kOut].get(s); - CHECK_EQ(cudnnCreateSpatialTransformerDescriptor(&st_desc_), CUDNN_STATUS_SUCCESS); - CHECK_EQ(cudnnCreateTensorDescriptor(&in_desc_), CUDNN_STATUS_SUCCESS); - CHECK_EQ(cudnnCreateTensorDescriptor(&out_desc_), CUDNN_STATUS_SUCCESS); - CHECK_EQ(cudnnSetTensor4dDescriptor(in_desc_, - format_, - dtype_, - data.size(0), - data.size(1), - data.size(2), - data.size(3)), CUDNN_STATUS_SUCCESS); - CHECK_EQ(cudnnSetTensor4dDescriptor(out_desc_, - format_, - dtype_, - out.size(0), - out.size(1), - out.size(2), - out.size(3)), CUDNN_STATUS_SUCCESS); + CUDNN_CALL(cudnnCreateSpatialTransformerDescriptor(&st_desc_)); + CUDNN_CALL(cudnnCreateTensorDescriptor(&in_desc_)); + CUDNN_CALL(cudnnCreateTensorDescriptor(&out_desc_)); + CUDNN_CALL(cudnnSetTensor4dDescriptor(in_desc_, + format_, + dtype_, + data.size(0), + data.size(1), + data.size(2), + data.size(3))); + CUDNN_CALL(cudnnSetTensor4dDescriptor(out_desc_, + format_, + dtype_, + out.size(0), + out.size(1), + out.size(2), + out.size(3))); if (param_.sampler_type == st::kBilinear) { int dim[] = {static_cast(out.size(0)), static_cast(out.size(1)), static_cast(out.size(2)), static_cast(out.size(3))}; - CHECK_EQ(cudnnSetSpatialTransformerNdDescriptor(st_desc_, - sampler_, - dtype_, - 4, - dim) , CUDNN_STATUS_SUCCESS); + CUDNN_CALL(cudnnSetSpatialTransformerNdDescriptor(st_desc_, + sampler_, + dtype_, + 4, + dim)); } } } diff --git a/src/operator/operator_common.h b/src/operator/operator_common.h index df0a59be3aa9..a43d092bceb6 100755 --- a/src/operator/operator_common.h +++ b/src/operator/operator_common.h @@ -16,6 +16,7 @@ #include #include #include +#include "../common/cuda_utils.h" namespace mxnet { namespace op {