diff --git a/python/tvm/contrib/cudnn.py b/python/tvm/contrib/cudnn.py
index 2227d440126d..9b92c7cc2773 100644
--- a/python/tvm/contrib/cudnn.py
+++ b/python/tvm/contrib/cudnn.py
@@ -285,68 +285,6 @@ def conv_output_shape(
     return output
 
 
-def _conv_output_shape_from_cudnn(
-    tensor_format, pad, stride, dilation, x_shape, w_shape, data_dtype, conv_dtype, groups=1
-):
-    """Get output shape of 2D or 3D convolution. The output of this
-    function should be identical to that of conv_output_shape, but
-    requires a GPU with CuDNN to be present. This is maintained for
-    testing purposes to validate the output of conv_output_shape.
-
-    Paramters
-    ---------
-    tensor_format: int
-        0: CUDNN_TENSOR_NCHW
-        1: CUDNN_TENSOR_NHWC
-        2: CUDNN_TENSOR_NCHW_VECT_C
-    pad: int or list
-        padding
-    stride: int or list
-        stride
-    dilation: int or list
-        dilation
-    x_shape: list
-        input shape
-    w_shape: list
-        weight shape
-    data_dtype: str
-        data type
-    conv_dtype: str
-        convolution type
-    groups: int
-        number of groups
-
-    Returns
-    -------
-    oshape: list
-        output shape
-
-    """
-    dims = len(x_shape)
-    assert dims in (4, 5)
-
-    pad, stride, dilation, xshape, wshape = _prepare_global_func_params(
-        dims - 2, pad, stride, dilation, x_shape, w_shape
-    )
-    oshape = np.zeros((dims), dtype=np.int32)
-
-    func = tvm._ffi.get_global_func("tvm.contrib.cudnn.conv.output_shape_from_cudnn")
-    func(
-        tensor_format,
-        dims - 2,
-        _get_np_int32_array_handle(pad),
-        _get_np_int32_array_handle(stride),
-        _get_np_int32_array_handle(dilation),
-        _get_np_int32_array_handle(xshape),
-        _get_np_int32_array_handle(wshape),
-        _get_np_int32_array_handle(oshape),
-        data_dtype,
-        conv_dtype,
-        groups,
-    )
-    return list(oshape)
-
-
 def conv_find_algo(
     tensor_format,
     pad,
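With the CuDNN-backed cross-check removed, `conv_output_shape` is the single source of truth for shape inference and runs without a GPU or a CuDNN-enabled build. A minimal sketch of calling it, assuming it keeps the signature documented on the removed helper (the docstring above states the two functions were identical in output and arguments); treat the exact keywords as illustrative:

```python
# Illustrative sketch: keyword names follow the removed helper's docstring.
from tvm.contrib import cudnn

oshape = cudnn.conv_output_shape(
    tensor_format=0,  # CUDNN_TENSOR_NCHW
    pad=[3, 3],
    stride=[2, 2],
    dilation=[1, 1],
    x_shape=[1, 3, 224, 224],
    w_shape=[64, 3, 7, 7],
    data_dtype="float32",
    conv_dtype="float32",
)
# Standard cross-correlation arithmetic gives [1, 64, 112, 112]:
#   out = 1 + (in + 2 * pad - ((k - 1) * dilation + 1)) // stride
```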
diff --git a/src/runtime/contrib/cudnn/conv_forward.cc b/src/runtime/contrib/cudnn/conv_forward.cc
index 97704986792d..b7476e5106fa 100644
--- a/src/runtime/contrib/cudnn/conv_forward.cc
+++ b/src/runtime/contrib/cudnn/conv_forward.cc
@@ -37,89 +37,12 @@ void ConvolutionForward(int mode, int format, int algo, int dims, int groups, co
   CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal();
   // Set Mode
   entry_ptr->conv_entry.mode = static_cast<cudnnConvolutionMode_t>(mode);
-  // Set Format
-  entry_ptr->conv_entry.tensor_format = static_cast<cudnnTensorFormat_t>(format);
-  // Set Algo
-  entry_ptr->conv_entry.fwd_algo = static_cast<cudnnConvolutionFwdAlgo_t>(algo);
+  SetConvDescriptors(entry_ptr, format, dims, groups, pad, stride, dilation, x->shape, w->shape,
+                     y->shape, x->dtype, conv_dtype);
   // Set Device
   entry_ptr->conv_entry.device = x->device;
-  // Set Data Type
-  entry_ptr->conv_entry.data_type = CuDNNDataType::DLTypeToCuDNNType(String2DLDataType(conv_dtype));
-  cudnnDataType_t data_type = CuDNNDataType::DLTypeToCuDNNType(x->dtype);
-  // Dims includes N and C
-  int full_dims = dims + 2;
-
-  std::vector<int> dim(full_dims);
-  std::vector<int> tensor_stride(full_dims);
-
-  // Note: For 2D tenor, using ND setters causes CUDNN_STATUS_NOT_SUPPORTED error
-  // in following cudnnGetConvolutionForwardWorkspaceSize() when data type is fp16, int
-
-  CUDNN_CALL(cudnnSetConvolutionGroupCount(entry_ptr->conv_entry.conv_desc, groups));
-  if (dims == 2) {
-    // Set Desc
-    CUDNN_CALL(cudnnSetConvolution2dDescriptor(
-        entry_ptr->conv_entry.conv_desc, pad[0], pad[1], stride[0], stride[1], dilation[0],
-        dilation[1], entry_ptr->conv_entry.mode, entry_ptr->conv_entry.data_type));
-    int ni, ci, hi, wi;
-    if (entry_ptr->conv_entry.tensor_format == CUDNN_TENSOR_NHWC) {
-      ni = 0;
-      ci = 3;
-      hi = 1;
-      wi = 2;
-    } else {
-      ni = 0;
-      ci = 1;
-      hi = 2;
-      wi = 3;
-    }
-
-    // Set Filter
-    CUDNN_CALL(cudnnSetFilter4dDescriptor(
-        entry_ptr->conv_entry.filter_desc, data_type, entry_ptr->conv_entry.tensor_format,
-        static_cast<int>(w->shape[ni]), static_cast<int>(w->shape[ci]),
-        static_cast<int>(w->shape[hi]), static_cast<int>(w->shape[wi])));
-    // Set Input
-    CUDNN_CALL(cudnnSetTensor4dDescriptor(
-        entry_ptr->conv_entry.input_desc, entry_ptr->conv_entry.tensor_format, data_type,
-        static_cast<int>(x->shape[ni]), static_cast<int>(x->shape[ci]),
-        static_cast<int>(x->shape[hi]), static_cast<int>(x->shape[wi])));
-    // Set Output
-    CUDNN_CALL(cudnnSetTensor4dDescriptor(
-        entry_ptr->conv_entry.output_desc, entry_ptr->conv_entry.tensor_format, data_type,
-        static_cast<int>(y->shape[ni]), static_cast<int>(y->shape[ci]),
-        static_cast<int>(y->shape[hi]), static_cast<int>(y->shape[wi])));
-  } else {
-    CUDNN_CALL(cudnnSetConvolutionNdDescriptor(entry_ptr->conv_entry.conv_desc, dims, pad, stride,
-                                               dilation, entry_ptr->conv_entry.mode,
-                                               entry_ptr->conv_entry.data_type));
-
-    // Set Filter
-    for (int i = 0; i < full_dims; i++) {
-      dim[i] = static_cast<int>(w->shape[i]);
-    }
-    CUDNN_CALL(cudnnSetFilterNdDescriptor(entry_ptr->conv_entry.filter_desc, data_type,
-                                          entry_ptr->conv_entry.tensor_format, full_dims,
-                                          dim.data()));
-    // Set Input
-    for (int i = 0; i < full_dims; i++) {
-      dim[i] = static_cast<int>(x->shape[i]);
-    }
-    GetCudnnStride(full_dims, dim.data(), tensor_stride.data());
-    CUDNN_CALL(cudnnSetTensorNdDescriptor(entry_ptr->conv_entry.input_desc, data_type, full_dims,
-                                          dim.data(), tensor_stride.data()));
-    // Set Output
-    for (int i = 0; i < full_dims; i++) {
-      dim[i] = static_cast<int>(y->shape[i]);
-    }
-    GetCudnnStride(full_dims, dim.data(), tensor_stride.data());
-    CUDNN_CALL(cudnnSetTensorNdDescriptor(entry_ptr->conv_entry.output_desc, data_type, full_dims,
-                                          dim.data(), tensor_stride.data()));
-  }
-
-  if (cudnnGetVersion() > 7000) {
-    CUDNN_CALL(cudnnSetConvolutionMathType(entry_ptr->conv_entry.conv_desc, CUDNN_TENSOR_OP_MATH))
-  }
+  // Set Algo
+  entry_ptr->conv_entry.fwd_algo = static_cast<cudnnConvolutionFwdAlgo_t>(algo);
 
   // Set workspace
   size_t workspace_size = 0;
@@ -137,125 +60,22 @@ void ConvolutionForward(int mode, int format, int algo, int dims, int groups, co
                               entry_ptr->conv_entry.output_desc, y->data));
 }
 
-void OutputShape(int format, int dims, int groups, const int pad[], const int stride[],
-                 const int dilation[], const int x_dim[], const int w_dim[], void* out_shape,
-                 const std::string& data_dtype, const std::string& conv_dtype) {
-  CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal();
-
-  // Set Data Type
-  entry_ptr->conv_entry.data_type = CuDNNDataType::DLTypeToCuDNNType(String2DLDataType(conv_dtype));
-  cudnnDataType_t data_type = CuDNNDataType::DLTypeToCuDNNType(String2DLDataType(data_dtype));
-  // Set Format
-  entry_ptr->conv_entry.tensor_format = static_cast<cudnnTensorFormat_t>(format);
-  // Dims includes N and C
-  int full_dims = dims + 2;
-
-  // conv desc
-  CUDNN_CALL(cudnnSetConvolutionGroupCount(entry_ptr->conv_entry.conv_desc, groups));
-  CUDNN_CALL(cudnnSetConvolutionNdDescriptor(entry_ptr->conv_entry.conv_desc, dims, pad, stride,
-                                             dilation, CUDNN_CROSS_CORRELATION,
-                                             entry_ptr->conv_entry.data_type));
-
-  if (entry_ptr->conv_entry.tensor_format == CUDNN_TENSOR_NHWC) {
-    ICHECK_EQ(full_dims, 4) << "Use of layout CUDNN_TENSOR_NHWC is only supported for 4d tensors";
-
-    // Set Input
-    CUDNN_CALL(cudnnSetTensor4dDescriptor(entry_ptr->conv_entry.input_desc,
-                                          entry_ptr->conv_entry.tensor_format, data_type, x_dim[0],
-                                          x_dim[3], x_dim[1], x_dim[2]));
-
-    // filter desc
-    CUDNN_CALL(cudnnSetFilter4dDescriptor(entry_ptr->conv_entry.filter_desc, data_type,
-                                          entry_ptr->conv_entry.tensor_format, w_dim[0], w_dim[3],
-                                          w_dim[1], w_dim[2]));
-
-    CUDNN_CALL(cudnnGetConvolution2dForwardOutputDim(
-        entry_ptr->conv_entry.conv_desc, entry_ptr->conv_entry.input_desc,
-        entry_ptr->conv_entry.filter_desc, static_cast<int*>(out_shape),
-        static_cast<int*>(out_shape) + 3, static_cast<int*>(out_shape) + 1,
-        static_cast<int*>(out_shape) + 2));
-  } else {
-    // Set Input
-    std::vector<int> tensor_stride(full_dims);
-    GetCudnnStride(full_dims, x_dim, tensor_stride.data());
-
-    CUDNN_CALL(cudnnSetTensorNdDescriptor(entry_ptr->conv_entry.input_desc, data_type, full_dims,
-                                          x_dim, tensor_stride.data()));
-    // filter desc
-    CUDNN_CALL(cudnnSetFilterNdDescriptor(entry_ptr->conv_entry.filter_desc, data_type,
-                                          entry_ptr->conv_entry.tensor_format, full_dims, w_dim));
-
-    CUDNN_CALL(cudnnGetConvolutionNdForwardOutputDim(
-        entry_ptr->conv_entry.conv_desc, entry_ptr->conv_entry.input_desc,
-        entry_ptr->conv_entry.filter_desc, full_dims, static_cast<int*>(out_shape)));
-  }
-}
-
 void FindAlgo(int format, int dims, int groups, const int pad[], const int stride[],
               const int dilation[], const int x_dim[], const int w_dim[], const int y_dim[],
               const std::string& data_dtype, const std::string& conv_dtype, TVMRetValue* ret) {
   CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal();
-
-  // Set Data Type
-  entry_ptr->conv_entry.data_type = CuDNNDataType::DLTypeToCuDNNType(String2DLDataType(conv_dtype));
-  cudnnDataType_t data_type = CuDNNDataType::DLTypeToCuDNNType(String2DLDataType(data_dtype));
-  // Set Format
-  entry_ptr->conv_entry.tensor_format = static_cast<cudnnTensorFormat_t>(format);
-  // Dims includes N and C
-  int full_dims = dims + 2;
-
-  // conv desc
-  CUDNN_CALL(cudnnSetConvolutionGroupCount(entry_ptr->conv_entry.conv_desc, groups));
-
-  if (format == 1) {
-    ICHECK_EQ(full_dims, 4) << "Use of layout CUDNN_TENSOR_NHWC is only supported for 4d tensors";
-    int ni = 0;
-    int ci = 3;
-    int hi = 1;
-    int wi = 2;
-
-    // Set Input
-    CUDNN_CALL(cudnnSetTensor4dDescriptor(
-        entry_ptr->conv_entry.input_desc, entry_ptr->conv_entry.tensor_format, data_type,
-        static_cast<int>(x_dim[ni]), static_cast<int>(x_dim[ci]), static_cast<int>(x_dim[hi]),
-        static_cast<int>(x_dim[wi])));
-
-    CUDNN_CALL(cudnnSetFilter4dDescriptor(
-        entry_ptr->conv_entry.filter_desc, data_type, entry_ptr->conv_entry.tensor_format,
-        static_cast<int>(w_dim[ni]), static_cast<int>(w_dim[ci]), static_cast<int>(w_dim[hi]),
-        static_cast<int>(w_dim[wi])));
-    // Set Output
-    CUDNN_CALL(cudnnSetTensor4dDescriptor(
-        entry_ptr->conv_entry.output_desc, entry_ptr->conv_entry.tensor_format, data_type,
-        static_cast<int>(y_dim[ni]), static_cast<int>(y_dim[ci]), static_cast<int>(y_dim[hi]),
-        static_cast<int>(y_dim[wi])));
-
-    CUDNN_CALL(cudnnSetConvolution2dDescriptor(
-        entry_ptr->conv_entry.conv_desc, pad[0], pad[1], stride[0], stride[1], dilation[0],
-        dilation[1], entry_ptr->conv_entry.mode, entry_ptr->conv_entry.data_type));
-  } else {
-    CUDNN_CALL(cudnnSetConvolutionNdDescriptor(entry_ptr->conv_entry.conv_desc, dims, pad, stride,
-                                               dilation, CUDNN_CROSS_CORRELATION,
-                                               entry_ptr->conv_entry.data_type));
-
-    std::vector<int> tensor_stride(full_dims);
-    // input desc
-    GetCudnnStride(full_dims, x_dim, tensor_stride.data());
-    CUDNN_CALL(cudnnSetTensorNdDescriptor(entry_ptr->conv_entry.input_desc, data_type, full_dims,
-                                          x_dim, tensor_stride.data()));
-    // filter desc
-    CUDNN_CALL(cudnnSetFilterNdDescriptor(entry_ptr->conv_entry.filter_desc, data_type,
-                                          entry_ptr->conv_entry.tensor_format, full_dims, w_dim));
-
-    // output desc
-    GetCudnnStride(full_dims, y_dim, tensor_stride.data());
-    CUDNN_CALL(cudnnSetTensorNdDescriptor(entry_ptr->conv_entry.output_desc, data_type, full_dims,
-                                          y_dim, tensor_stride.data()));
-  }
-
-  if (cudnnGetVersion() > 7000) {
-    CUDNN_CALL(cudnnSetConvolutionMathType(entry_ptr->conv_entry.conv_desc, CUDNN_TENSOR_OP_MATH))
+  const int full_dims = dims + 2;
+  std::vector<int64_t> x_dim_int64(full_dims);
+  std::vector<int64_t> w_dim_int64(full_dims);
+  std::vector<int64_t> y_dim_int64(full_dims);
+  for (int i = 0; i < full_dims; ++i) {
+    x_dim_int64[i] = x_dim[i];
+    w_dim_int64[i] = w_dim[i];
+    y_dim_int64[i] = y_dim[i];
   }
+  SetConvDescriptors(entry_ptr, format, dims, groups, pad, stride, dilation, x_dim_int64.data(),
+                     w_dim_int64.data(), y_dim_int64.data(), String2DLDataType(data_dtype),
+                     conv_dtype);
 
   int returned_algo_count = 0;
   cudnnConvolutionFwdAlgoPerf_t perf_results[CUDNN_CONVOLUTION_FWD_ALGO_COUNT];
@@ -327,24 +147,6 @@ TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv3d.forward")
                          conv_dtype);
     });
 
-TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv.output_shape_from_cudnn")
-    .set_body([](TVMArgs args, TVMRetValue* ret) {
-      int format = args[0];
-      int dims = args[1];
-      int* pad = static_cast<int*>(static_cast<void*>(args[2]));
-      int* stride = static_cast<int*>(static_cast<void*>(args[3]));
-      int* dilation = static_cast<int*>(static_cast<void*>(args[4]));
-      int* x_dim = static_cast<int*>(static_cast<void*>(args[5]));
-      int* w_dim = static_cast<int*>(static_cast<void*>(args[6]));
-      void* out_shape = args[7];
-      std::string data_dtype = args[8];
-      std::string conv_dtype = args[9];
-      int groups = args[10];
-
-      OutputShape(format, dims, groups, pad, stride, dilation, x_dim, w_dim, out_shape, data_dtype,
-                  conv_dtype);
-    });
-
 TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv.find_algo")
     .set_body([](TVMArgs args, TVMRetValue* ret) {
       int format = args[0];
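Both surviving entry points (`conv2d.forward`/`conv3d.forward` and `conv.find_algo`) now funnel their descriptor setup through the shared `SetConvDescriptors`. For orientation, a minimal sketch of how this forward path is reached from Python, assuming the `conv_forward` wrapper in `python/tvm/contrib/cudnn.py` keeps its signature `(x, w, pad, stride, dilation, conv_mode, tensor_format, algo, conv_dtype, groups)`; treat this as illustrative, not canonical:

```python
# A minimal sketch, assuming the conv_forward wrapper's documented signature.
from tvm import te
from tvm.contrib import cudnn

X = te.placeholder((1, 3, 224, 224), name="x", dtype="float32")
W = te.placeholder((64, 3, 7, 7), name="w", dtype="float32")
Y = cudnn.conv_forward(
    X, W,
    [3, 3], [2, 2], [1, 1],  # pad, stride, dilation
    conv_mode=1,             # CUDNN_CROSS_CORRELATION
    tensor_format=0,         # CUDNN_TENSOR_NCHW
    algo=-1,                 # let the wrapper choose via conv.find_algo
    conv_dtype="float32",
)
```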
diff --git a/src/runtime/contrib/cudnn/cudnn_utils.cc b/src/runtime/contrib/cudnn/cudnn_utils.cc
index 297cd9e7a361..e39c47339c7f 100644
--- a/src/runtime/contrib/cudnn/cudnn_utils.cc
+++ b/src/runtime/contrib/cudnn/cudnn_utils.cc
@@ -20,11 +20,16 @@
 /*!
  * \file Use external cudnn utils function
  */
+
 #include "cudnn_utils.h"
 
 #include <dmlc/thread_local.h>
+#include <tvm/runtime/data_type.h>
 #include <tvm/runtime/registry.h>
 
+#include <string>
+#include <vector>
+
 namespace tvm {
 namespace contrib {
 
@@ -160,6 +165,96 @@ void ConvEntry::CleanWorkspace() {
   workspace_size = 0;
 }
 
+void SetConvDescriptors(CuDNNThreadEntry* entry_ptr, int format, int dims, int groups,
+                        const int pad[], const int stride[], const int dilation[], int64_t x_dim[],
+                        int64_t w_dim[], int64_t y_dim[], DLDataType data_dtype,
+                        const std::string& conv_dtype) {
+  // Set Format
+  entry_ptr->conv_entry.tensor_format = static_cast<cudnnTensorFormat_t>(format);
+  // Set Data Type
+  entry_ptr->conv_entry.data_type =
+      CuDNNDataType::DLTypeToCuDNNType(runtime::String2DLDataType(conv_dtype));
+
+  cudnnDataType_t cudnn_data_type = CuDNNDataType::DLTypeToCuDNNType(data_dtype);
+
+  // Dims includes N and C
+  int full_dims = dims + 2;
+
+  std::vector<int> dim(full_dims);
+  std::vector<int> tensor_stride(full_dims);
+
+  // Note: For a 2D tensor, using ND setters causes a CUDNN_STATUS_NOT_SUPPORTED error
+  // in the following cudnnGetConvolutionForwardWorkspaceSize() when data type is fp16, int
+
+  CUDNN_CALL(cudnnSetConvolutionGroupCount(entry_ptr->conv_entry.conv_desc, groups));
+  if (dims == 2) {
+    // Set Desc
+    CUDNN_CALL(cudnnSetConvolution2dDescriptor(
+        entry_ptr->conv_entry.conv_desc, pad[0], pad[1], stride[0], stride[1], dilation[0],
+        dilation[1], entry_ptr->conv_entry.mode, entry_ptr->conv_entry.data_type));
+    int ni, ci, hi, wi;
+    if (entry_ptr->conv_entry.tensor_format == CUDNN_TENSOR_NHWC) {
+      ni = 0;
+      ci = 3;
+      hi = 1;
+      wi = 2;
+    } else {
+      ni = 0;
+      ci = 1;
+      hi = 2;
+      wi = 3;
+    }
+
+    // Set Input
+    CUDNN_CALL(cudnnSetTensor4dDescriptor(
+        entry_ptr->conv_entry.input_desc, entry_ptr->conv_entry.tensor_format, cudnn_data_type,
+        static_cast<int>(x_dim[ni]), static_cast<int>(x_dim[ci]), static_cast<int>(x_dim[hi]),
+        static_cast<int>(x_dim[wi])));
+    // Set Filter
+    CUDNN_CALL(cudnnSetFilter4dDescriptor(
+        entry_ptr->conv_entry.filter_desc, cudnn_data_type, entry_ptr->conv_entry.tensor_format,
+        static_cast<int>(w_dim[ni]), static_cast<int>(w_dim[ci]), static_cast<int>(w_dim[hi]),
+        static_cast<int>(w_dim[wi])));
+    // Set Output
+    CUDNN_CALL(cudnnSetTensor4dDescriptor(
+        entry_ptr->conv_entry.output_desc, entry_ptr->conv_entry.tensor_format, cudnn_data_type,
+        static_cast<int>(y_dim[ni]), static_cast<int>(y_dim[ci]), static_cast<int>(y_dim[hi]),
+        static_cast<int>(y_dim[wi])));
+  } else {
+    ICHECK_EQ(format, 0) << "Use of layout CUDNN_TENSOR_NHWC is supported only for 4-D tensors.";
+
+    CUDNN_CALL(cudnnSetConvolutionNdDescriptor(entry_ptr->conv_entry.conv_desc, dims, pad, stride,
+                                               dilation, entry_ptr->conv_entry.mode,
+                                               entry_ptr->conv_entry.data_type));
+
+    // Set Filter
+    for (int i = 0; i < full_dims; i++) {
+      dim[i] = static_cast<int>(w_dim[i]);
+    }
+    CUDNN_CALL(cudnnSetFilterNdDescriptor(entry_ptr->conv_entry.filter_desc, cudnn_data_type,
+                                          entry_ptr->conv_entry.tensor_format, full_dims,
+                                          dim.data()));
+    // Set Input
+    for (int i = 0; i < full_dims; i++) {
+      dim[i] = static_cast<int>(x_dim[i]);
+    }
+    GetCudnnStride(full_dims, dim.data(), tensor_stride.data());
+    CUDNN_CALL(cudnnSetTensorNdDescriptor(entry_ptr->conv_entry.input_desc, cudnn_data_type,
+                                          full_dims, dim.data(), tensor_stride.data()));
+    // Set Output
+    for (int i = 0; i < full_dims; i++) {
+      dim[i] = static_cast<int>(y_dim[i]);
+    }
+    GetCudnnStride(full_dims, dim.data(), tensor_stride.data());
+    CUDNN_CALL(cudnnSetTensorNdDescriptor(entry_ptr->conv_entry.output_desc, cudnn_data_type,
+                                          full_dims, dim.data(), tensor_stride.data()));
+  }
+
+  if (cudnnGetVersion() > 7000) {
+    CUDNN_CALL(cudnnSetConvolutionMathType(entry_ptr->conv_entry.conv_desc, CUDNN_TENSOR_OP_MATH))
+  }
+}
+
 // SoftmaxEntry
 SoftmaxEntry::SoftmaxEntry() { CUDNN_CALL(cudnnCreateTensorDescriptor(&shape_desc)); }
 
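Why the shared helper takes `int64_t` dims: `DLTensor::shape` is an `int64_t*`, so `ConvolutionForward` can pass `x->shape` through unchanged, while `FindAlgo`, which receives `int32` arrays from the Python side, widens first. Note also that `SetConvDescriptors` reads `conv_entry.mode`, which the old `FindAlgo` and `OutputShape` never set (they hard-coded `CUDNN_CROSS_CORRELATION`); the new member default in `cudnn_utils.h` below preserves that behavior. The widening itself is a plain element-wise copy, equivalent to this sketch:

```python
import numpy as np

# Equivalent of FindAlgo's per-element copy: int32 dims from the Python FFI
# are widened to the int64 element type that DLTensor shapes use.
x_dim = np.array([1, 3, 224, 224], dtype=np.int32)
x_dim_int64 = x_dim.astype(np.int64)
assert x_dim_int64.dtype == np.int64
```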
diff --git a/src/runtime/contrib/cudnn/cudnn_utils.h b/src/runtime/contrib/cudnn/cudnn_utils.h
index 01b92d61e66e..89de0e90df90 100644
--- a/src/runtime/contrib/cudnn/cudnn_utils.h
+++ b/src/runtime/contrib/cudnn/cudnn_utils.h
@@ -28,6 +28,8 @@
 #include <cudnn.h>
 #include <tvm/runtime/device_api.h>
 
+#include <string>
+
 #include "../../cuda/cuda_common.h"
 
 namespace tvm {
@@ -64,7 +66,7 @@ inline void GetCudnnStride(int nbdim, const int* dims, int* strides) {
 
 struct ConvEntry {
   cudnnConvolutionDescriptor_t conv_desc;
-  cudnnConvolutionMode_t mode;
+  cudnnConvolutionMode_t mode{CUDNN_CROSS_CORRELATION};
   cudnnFilterDescriptor_t filter_desc;
   cudnnDataType_t data_type;
   cudnnTensorFormat_t tensor_format;
@@ -103,6 +105,11 @@ struct CuDNNThreadEntry {
   static CuDNNThreadEntry* ThreadLocal(bool check_exists = true);
 };  // CuDNNThreadEntry
 
+void SetConvDescriptors(CuDNNThreadEntry* entry_ptr, int format, int dims, int groups,
+                        const int pad[], const int stride[], const int dilation[], int64_t x_dim[],
+                        int64_t w_dim[], int64_t y_dim[], DLDataType data_dtype,
+                        const std::string& conv_dtype);
+
 }  // namespace contrib
 }  // namespace tvm
 
diff --git a/tests/python/contrib/test_cudnn.py b/tests/python/contrib/test_cudnn.py
index 7f504fcc1ed7..bc2cc80f362d 100644
--- a/tests/python/contrib/test_cudnn.py
+++ b/tests/python/contrib/test_cudnn.py
@@ -29,7 +29,7 @@
 
 
 requires_cudnn = pytest.mark.skipif(
-    tvm.get_global_func("tvm.contrib.cudnn.conv.output_shape_from_cudnn", True) is None,
+    tvm.get_global_func("tvm.contrib.cudnn.conv2d.forward", True) is None,
     reason="CuDNN is not enabled",
 )
 
@@ -307,13 +307,5 @@ def conv_output_shape_kwargs(request):
     return request.param
 
 
-@tvm.testing.requires_gpu
-@requires_cudnn
-def test_conv_output_shape(conv_output_shape_kwargs):
-    shape_from_cudnn = cudnn._conv_output_shape_from_cudnn(**conv_output_shape_kwargs)
-    shape_from_python = cudnn.conv_output_shape(**conv_output_shape_kwargs)
-    assert shape_from_cudnn == shape_from_python
-
-
 if __name__ == "__main__":
     sys.exit(pytest.main(sys.argv))
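Since `tvm.contrib.cudnn.conv.output_shape_from_cudnn` no longer exists, downstream code that used it purely as a "was TVM built with CuDNN?" probe should switch to any surviving CuDNN global, exactly as the test guard above now does. A sketch of the pattern:

```python
import tvm

# Probe a registered CuDNN global (allow_missing=True) to detect whether
# this TVM build has CuDNN enabled, mirroring the updated test guards.
def cudnn_enabled() -> bool:
    return tvm.get_global_func("tvm.contrib.cudnn.conv2d.forward", True) is not None
```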
diff --git a/tests/python/relay/test_any.py b/tests/python/relay/test_any.py
index 9ac04d4933a6..da36bba96556 100644
--- a/tests/python/relay/test_any.py
+++ b/tests/python/relay/test_any.py
@@ -541,7 +541,7 @@ def verify_any_conv2d(
     kernel_np = np.random.uniform(size=kernel_shape).astype(dtype)
 
     targets = None
-    if use_cudnn and tvm.get_global_func("tvm.contrib.cudnn.conv.output_shape_from_cudnn", True):
+    if use_cudnn and tvm.get_global_func("tvm.contrib.cudnn.conv2d.forward", True):
         targets = [("cuda -libs=cudnn", tvm.cuda(0))]
 
     check_result([data_np, kernel_np], mod, ref_out_shape, assert_shape=True, targets=targets)