[CUDNN] Refactor descriptor initialization, remove `cudnn.conv.output_shape_from_cudnn` (#9948)

* Introduce SetConvDescriptors to refactor cudnn/conv_forward.cc

* More refactoring

* Remove the cuDNN output shape query

* cpplint fixes
masahi authored Jan 18, 2022
1 parent 1e5373f commit 211291f
Showing 6 changed files with 120 additions and 286 deletions.
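
At the core of the change is a consolidated helper, SetConvDescriptors, which replaces the descriptor-setup code previously duplicated across ConvolutionForward, OutputShape, and FindAlgo. Its declaration lives in a changed file not shown in this excerpt; the following signature is a sketch inferred from the two call sites visible in the diff below, so treat the parameter types as assumptions:

    // Inferred sketch, not the verbatim declaration from this commit.
    // Populates the conv/filter/input/output descriptors in conv_entry from the
    // given geometry; shapes arrive as int64_t[] (DLTensor-style), dtypes as a
    // DLDataType plus a string naming the accumulation (convolution) type.
    void SetConvDescriptors(CuDNNThreadEntry* entry_ptr, int format, int dims, int groups,
                            const int pad[], const int stride[], const int dilation[],
                            const int64_t x_dim[], const int64_t w_dim[], const int64_t y_dim[],
                            DLDataType data_dtype, const std::string& conv_dtype);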
62 changes: 0 additions & 62 deletions python/tvm/contrib/cudnn.py
@@ -285,68 +285,6 @@ def conv_output_shape(
return output


def _conv_output_shape_from_cudnn(
tensor_format, pad, stride, dilation, x_shape, w_shape, data_dtype, conv_dtype, groups=1
):
"""Get output shape of 2D or 3D convolution. The output of this
function should be identical to that of conv_output_shape, but
requires a GPU with CuDNN to be present. This is maintained for
testing purposes to validate the output of conv_output_shape.
Parameters
----------
tensor_format: int
0: CUDNN_TENSOR_NCHW
1: CUDNN_TENSOR_NHWC
2: CUDNN_TENSOR_NCHW_VECT_C
pad: int or list
padding
stride: int or list
stride
dilation: int or list
dilation
x_shape: list
input shape
w_shape: list
weight shape
data_dtype: str
data type
conv_dtype: str
convolution type
groups: int
number of groups
Returns
-------
oshape: list
output shape
"""
dims = len(x_shape)
assert dims in (4, 5)

pad, stride, dilation, xshape, wshape = _prepare_global_func_params(
dims - 2, pad, stride, dilation, x_shape, w_shape
)
oshape = np.zeros((dims), dtype=np.int32)

func = tvm._ffi.get_global_func("tvm.contrib.cudnn.conv.output_shape_from_cudnn")
func(
tensor_format,
dims - 2,
_get_np_int32_array_handle(pad),
_get_np_int32_array_handle(stride),
_get_np_int32_array_handle(dilation),
_get_np_int32_array_handle(xshape),
_get_np_int32_array_handle(wshape),
_get_np_int32_array_handle(oshape),
data_dtype,
conv_dtype,
groups,
)
return list(oshape)


def conv_find_algo(
tensor_format,
pad,
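
With the cuDNN-backed cross-check gone, the pure-Python conv_output_shape above is the single source of truth for output shapes. For reference, here is a minimal sketch of the standard cross-correlation output-shape arithmetic it computes (written in C++; the function name and the NCHW layout assumption are illustrative, not from this commit):

    #include <cstdint>
    #include <vector>

    // out[0] = batch, out[1] = output channels; each spatial dim follows the
    // formula cuDNN documents for cudnnGetConvolutionNdForwardOutputDim.
    std::vector<int64_t> ConvOutputShapeNCHW(const std::vector<int64_t>& x_dim,
                                             const std::vector<int64_t>& w_dim,
                                             const std::vector<int>& pad,
                                             const std::vector<int>& stride,
                                             const std::vector<int>& dilation) {
      std::vector<int64_t> out(x_dim.size());
      out[0] = x_dim[0];  // N
      out[1] = w_dim[0];  // K (number of filters)
      for (size_t i = 2; i < x_dim.size(); ++i) {
        out[i] = (x_dim[i] + 2 * pad[i - 2] - dilation[i - 2] * (w_dim[i] - 1) - 1)
                 / stride[i - 2] + 1;
      }
      return out;
    }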
228 changes: 15 additions & 213 deletions src/runtime/contrib/cudnn/conv_forward.cc
@@ -37,89 +37,12 @@ void ConvolutionForward(int mode, int format, int algo, int dims, int groups, co
CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal();
// Set Mode
entry_ptr->conv_entry.mode = static_cast<cudnnConvolutionMode_t>(mode);
// Set Format
entry_ptr->conv_entry.tensor_format = static_cast<cudnnTensorFormat_t>(format);
// Set Algo
entry_ptr->conv_entry.fwd_algo = static_cast<cudnnConvolutionFwdAlgo_t>(algo);
SetConvDescriptors(entry_ptr, format, dims, groups, pad, stride, dilation, x->shape, w->shape,
y->shape, x->dtype, conv_dtype);
// Set Device
entry_ptr->conv_entry.device = x->device;
// Set Data Type
entry_ptr->conv_entry.data_type = CuDNNDataType::DLTypeToCuDNNType(String2DLDataType(conv_dtype));
cudnnDataType_t data_type = CuDNNDataType::DLTypeToCuDNNType(x->dtype);
// Dims includes N and C
int full_dims = dims + 2;

std::vector<int> dim(full_dims);
std::vector<int> tensor_stride(full_dims);

// Note: For a 2D tensor, using the ND setters causes a CUDNN_STATUS_NOT_SUPPORTED error
// in the subsequent cudnnGetConvolutionForwardWorkspaceSize() when the data type is fp16 or int

CUDNN_CALL(cudnnSetConvolutionGroupCount(entry_ptr->conv_entry.conv_desc, groups));
if (dims == 2) {
// Set Desc
CUDNN_CALL(cudnnSetConvolution2dDescriptor(
entry_ptr->conv_entry.conv_desc, pad[0], pad[1], stride[0], stride[1], dilation[0],
dilation[1], entry_ptr->conv_entry.mode, entry_ptr->conv_entry.data_type));
int ni, ci, hi, wi;
if (entry_ptr->conv_entry.tensor_format == CUDNN_TENSOR_NHWC) {
ni = 0;
ci = 3;
hi = 1;
wi = 2;
} else {
ni = 0;
ci = 1;
hi = 2;
wi = 3;
}

// Set Filter
CUDNN_CALL(cudnnSetFilter4dDescriptor(
entry_ptr->conv_entry.filter_desc, data_type, entry_ptr->conv_entry.tensor_format,
static_cast<int>(w->shape[ni]), static_cast<int>(w->shape[ci]),
static_cast<int>(w->shape[hi]), static_cast<int>(w->shape[wi])));
// Set Input
CUDNN_CALL(cudnnSetTensor4dDescriptor(
entry_ptr->conv_entry.input_desc, entry_ptr->conv_entry.tensor_format, data_type,
static_cast<int>(x->shape[ni]), static_cast<int>(x->shape[ci]),
static_cast<int>(x->shape[hi]), static_cast<int>(x->shape[wi])));
// Set Output
CUDNN_CALL(cudnnSetTensor4dDescriptor(
entry_ptr->conv_entry.output_desc, entry_ptr->conv_entry.tensor_format, data_type,
static_cast<int>(y->shape[ni]), static_cast<int>(y->shape[ci]),
static_cast<int>(y->shape[hi]), static_cast<int>(y->shape[wi])));
} else {
CUDNN_CALL(cudnnSetConvolutionNdDescriptor(entry_ptr->conv_entry.conv_desc, dims, pad, stride,
dilation, entry_ptr->conv_entry.mode,
entry_ptr->conv_entry.data_type));

// Set Filter
for (int i = 0; i < full_dims; i++) {
dim[i] = static_cast<int>(w->shape[i]);
}
CUDNN_CALL(cudnnSetFilterNdDescriptor(entry_ptr->conv_entry.filter_desc, data_type,
entry_ptr->conv_entry.tensor_format, full_dims,
dim.data()));
// Set Input
for (int i = 0; i < full_dims; i++) {
dim[i] = static_cast<int>(x->shape[i]);
}
GetCudnnStride(full_dims, dim.data(), tensor_stride.data());
CUDNN_CALL(cudnnSetTensorNdDescriptor(entry_ptr->conv_entry.input_desc, data_type, full_dims,
dim.data(), tensor_stride.data()));
// Set Output
for (int i = 0; i < full_dims; i++) {
dim[i] = static_cast<int>(y->shape[i]);
}
GetCudnnStride(full_dims, dim.data(), tensor_stride.data());
CUDNN_CALL(cudnnSetTensorNdDescriptor(entry_ptr->conv_entry.output_desc, data_type, full_dims,
dim.data(), tensor_stride.data()));
}

if (cudnnGetVersion() > 7000) {
CUDNN_CALL(cudnnSetConvolutionMathType(entry_ptr->conv_entry.conv_desc, CUDNN_TENSOR_OP_MATH))
}
// Set Algo
entry_ptr->conv_entry.fwd_algo = static_cast<cudnnConvolutionFwdAlgo_t>(algo);

// Set workspace
size_t workspace_size = 0;
@@ -137,125 +60,22 @@ void ConvolutionForward(int mode, int format, int algo, int dims, int groups, co
entry_ptr->conv_entry.output_desc, y->data));
}

void OutputShape(int format, int dims, int groups, const int pad[], const int stride[],
const int dilation[], const int x_dim[], const int w_dim[], void* out_shape,
const std::string& data_dtype, const std::string& conv_dtype) {
CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal();

// Set Data Type
entry_ptr->conv_entry.data_type = CuDNNDataType::DLTypeToCuDNNType(String2DLDataType(conv_dtype));
cudnnDataType_t data_type = CuDNNDataType::DLTypeToCuDNNType(String2DLDataType(data_dtype));
// Set Format
entry_ptr->conv_entry.tensor_format = static_cast<cudnnTensorFormat_t>(format);
// Dims includes N and C
int full_dims = dims + 2;

// conv desc
CUDNN_CALL(cudnnSetConvolutionGroupCount(entry_ptr->conv_entry.conv_desc, groups));
CUDNN_CALL(cudnnSetConvolutionNdDescriptor(entry_ptr->conv_entry.conv_desc, dims, pad, stride,
dilation, CUDNN_CROSS_CORRELATION,
entry_ptr->conv_entry.data_type));

if (entry_ptr->conv_entry.tensor_format == CUDNN_TENSOR_NHWC) {
ICHECK_EQ(full_dims, 4) << "Use of layout CUDNN_TENSOR_NHWC is only supported for 4d tensors";

// Set Input
CUDNN_CALL(cudnnSetTensor4dDescriptor(entry_ptr->conv_entry.input_desc,
entry_ptr->conv_entry.tensor_format, data_type, x_dim[0],
x_dim[3], x_dim[1], x_dim[2]));

// filter desc
CUDNN_CALL(cudnnSetFilter4dDescriptor(entry_ptr->conv_entry.filter_desc, data_type,
entry_ptr->conv_entry.tensor_format, w_dim[0], w_dim[3],
w_dim[1], w_dim[2]));

CUDNN_CALL(cudnnGetConvolution2dForwardOutputDim(
entry_ptr->conv_entry.conv_desc, entry_ptr->conv_entry.input_desc,
entry_ptr->conv_entry.filter_desc, static_cast<int*>(out_shape),
static_cast<int*>(out_shape) + 3, static_cast<int*>(out_shape) + 1,
static_cast<int*>(out_shape) + 2));
} else {
// Set Input
std::vector<int> tensor_stride(full_dims);
GetCudnnStride(full_dims, x_dim, tensor_stride.data());

CUDNN_CALL(cudnnSetTensorNdDescriptor(entry_ptr->conv_entry.input_desc, data_type, full_dims,
x_dim, tensor_stride.data()));
// filter desc
CUDNN_CALL(cudnnSetFilterNdDescriptor(entry_ptr->conv_entry.filter_desc, data_type,
entry_ptr->conv_entry.tensor_format, full_dims, w_dim));

CUDNN_CALL(cudnnGetConvolutionNdForwardOutputDim(
entry_ptr->conv_entry.conv_desc, entry_ptr->conv_entry.input_desc,
entry_ptr->conv_entry.filter_desc, full_dims, static_cast<int*>(out_shape)));
}
}

void FindAlgo(int format, int dims, int groups, const int pad[], const int stride[],
const int dilation[], const int x_dim[], const int w_dim[], const int y_dim[],
const std::string& data_dtype, const std::string& conv_dtype, TVMRetValue* ret) {
CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal();

// Set Data Type
entry_ptr->conv_entry.data_type = CuDNNDataType::DLTypeToCuDNNType(String2DLDataType(conv_dtype));
cudnnDataType_t data_type = CuDNNDataType::DLTypeToCuDNNType(String2DLDataType(data_dtype));
// Set Format
entry_ptr->conv_entry.tensor_format = static_cast<cudnnTensorFormat_t>(format);
// Dims includes N and C
int full_dims = dims + 2;

// conv desc
CUDNN_CALL(cudnnSetConvolutionGroupCount(entry_ptr->conv_entry.conv_desc, groups));

if (format == 1) {
ICHECK_EQ(full_dims, 4) << "Use of layout CUDNN_TENSOR_NHWC is only supported for 4d tensors";
int ni = 0;
int ci = 3;
int hi = 1;
int wi = 2;

// Set Input
CUDNN_CALL(cudnnSetTensor4dDescriptor(
entry_ptr->conv_entry.input_desc, entry_ptr->conv_entry.tensor_format, data_type,
static_cast<int>(x_dim[ni]), static_cast<int>(x_dim[ci]), static_cast<int>(x_dim[hi]),
static_cast<int>(x_dim[wi])));

CUDNN_CALL(cudnnSetFilter4dDescriptor(
entry_ptr->conv_entry.filter_desc, data_type, entry_ptr->conv_entry.tensor_format,
static_cast<int>(w_dim[ni]), static_cast<int>(w_dim[ci]), static_cast<int>(w_dim[hi]),
static_cast<int>(w_dim[wi])));
// Set Output
CUDNN_CALL(cudnnSetTensor4dDescriptor(
entry_ptr->conv_entry.output_desc, entry_ptr->conv_entry.tensor_format, data_type,
static_cast<int>(y_dim[ni]), static_cast<int>(y_dim[ci]), static_cast<int>(y_dim[hi]),
static_cast<int>(y_dim[wi])));

CUDNN_CALL(cudnnSetConvolution2dDescriptor(
entry_ptr->conv_entry.conv_desc, pad[0], pad[1], stride[0], stride[1], dilation[0],
dilation[1], entry_ptr->conv_entry.mode, entry_ptr->conv_entry.data_type));
} else {
CUDNN_CALL(cudnnSetConvolutionNdDescriptor(entry_ptr->conv_entry.conv_desc, dims, pad, stride,
dilation, CUDNN_CROSS_CORRELATION,
entry_ptr->conv_entry.data_type));

std::vector<int> tensor_stride(full_dims);
// input desc
GetCudnnStride(full_dims, x_dim, tensor_stride.data());
CUDNN_CALL(cudnnSetTensorNdDescriptor(entry_ptr->conv_entry.input_desc, data_type, full_dims,
x_dim, tensor_stride.data()));
// filter desc
CUDNN_CALL(cudnnSetFilterNdDescriptor(entry_ptr->conv_entry.filter_desc, data_type,
entry_ptr->conv_entry.tensor_format, full_dims, w_dim));

// output desc
GetCudnnStride(full_dims, y_dim, tensor_stride.data());
CUDNN_CALL(cudnnSetTensorNdDescriptor(entry_ptr->conv_entry.output_desc, data_type, full_dims,
y_dim, tensor_stride.data()));
}

if (cudnnGetVersion() > 7000) {
CUDNN_CALL(cudnnSetConvolutionMathType(entry_ptr->conv_entry.conv_desc, CUDNN_TENSOR_OP_MATH))
const int full_dims = dims + 2;
std::vector<int64_t> x_dim_int64(full_dims);
std::vector<int64_t> w_dim_int64(full_dims);
std::vector<int64_t> y_dim_int64(full_dims);
for (int i = 0; i < full_dims; ++i) {
x_dim_int64[i] = x_dim[i];
w_dim_int64[i] = w_dim[i];
y_dim_int64[i] = y_dim[i];
}
SetConvDescriptors(entry_ptr, format, dims, groups, pad, stride, dilation, x_dim_int64.data(),
w_dim_int64.data(), y_dim_int64.data(), String2DLDataType(data_dtype),
conv_dtype);

int returned_algo_count = 0;
cudnnConvolutionFwdAlgoPerf_t perf_results[CUDNN_CONVOLUTION_FWD_ALGO_COUNT];
@@ -327,24 +147,6 @@ TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv3d.forward")
conv_dtype);
});

TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv.output_shape_from_cudnn")
.set_body([](TVMArgs args, TVMRetValue* ret) {
int format = args[0];
int dims = args[1];
int* pad = static_cast<int*>(static_cast<void*>(args[2]));
int* stride = static_cast<int*>(static_cast<void*>(args[3]));
int* dilation = static_cast<int*>(static_cast<void*>(args[4]));
int* x_dim = static_cast<int*>(static_cast<void*>(args[5]));
int* w_dim = static_cast<int*>(static_cast<void*>(args[6]));
void* out_shape = args[7];
std::string data_dtype = args[8];
std::string conv_dtype = args[9];
int groups = args[10];

OutputShape(format, dims, groups, pad, stride, dilation, x_dim, w_dim, out_shape, data_dtype,
conv_dtype);
});

TVM_REGISTER_GLOBAL("tvm.contrib.cudnn.conv.find_algo")
.set_body([](TVMArgs args, TVMRetValue* ret) {
int format = args[0];
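
The FindAlgo body is truncated above. For context, here is a hedged sketch of how the algorithm search plausibly proceeds once SetConvDescriptors has populated the descriptors, using cuDNN's public API (an assumption about the surrounding code, not lines from the diff):

    // Ask cuDNN to benchmark the candidate forward algorithms; results come
    // back sorted fastest-first, so perf_results[0].algo is the best choice.
    int returned_algo_count = 0;
    cudnnConvolutionFwdAlgoPerf_t perf_results[CUDNN_CONVOLUTION_FWD_ALGO_COUNT];
    CUDNN_CALL(cudnnFindConvolutionForwardAlgorithm(
        entry_ptr->handle, entry_ptr->conv_entry.input_desc, entry_ptr->conv_entry.filter_desc,
        entry_ptr->conv_entry.conv_desc, entry_ptr->conv_entry.output_desc,
        CUDNN_CONVOLUTION_FWD_ALGO_COUNT, &returned_algo_count, perf_results));
    *ret = static_cast<int>(perf_results[0].algo);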