[CUBLAS, CUDNN] Support dynamic batch size #7194

Merged: 3 commits, Jan 4, 2021
81 changes: 51 additions & 30 deletions python/tvm/contrib/cudnn.py
@@ -342,36 +342,57 @@ def conv_forward(x, w, pad, stride, dilation, conv_mode, tensor_format, algo, co
     conv_dtype = x.dtype if conv_dtype is None else conv_dtype
     pad, stride, dilation, _, _ = _prepare_global_func_params(dims - 2, pad, stride, dilation)
 
-    oshape = conv_output_shape(
-        tensor_format,
-        pad,
-        stride,
-        dilation,
-        list(x.shape),
-        list(w.shape),
-        x.dtype,
-        conv_dtype,
-        groups,
-    )
-    if algo == -1:
-        # For now, if we try to call `cudnnFindConvolutionForwardAlgorithm` when
-        # using the INT8 data type, cuDNN will crash.
-        # On the other hand, cuDNN only supports IMPLICIT_PRECOMP_GEMM for the NHWC format.
-        if tensor_format == 1 and conv_dtype == "int32":
-            algo = 1
-        else:
-            algo = conv_find_algo(
-                tensor_format,
-                pad,
-                stride,
-                dilation,
-                list(x.shape),
-                list(w.shape),
-                oshape,
-                x.dtype,
-                conv_dtype,
-                groups,
-            )
+    x_shape = list(x.shape)
+
+    if isinstance(x.shape[0], tvm.tir.expr.IntImm):
+        oshape = conv_output_shape(
+            tensor_format,
+            pad,
+            stride,
+            dilation,
+            x_shape,
+            list(w.shape),
+            x.dtype,
+            conv_dtype,
+            groups,
+        )
+        if algo == -1:
+            # For now, if we try to call `cudnnFindConvolutionForwardAlgorithm` when
+            # using the INT8 data type, cuDNN will crash.
+            # On the other hand, cuDNN only supports IMPLICIT_PRECOMP_GEMM for the NHWC format.
+            if tensor_format == 1 and conv_dtype == "int32":
+                algo = 1
+            else:
+                algo = conv_find_algo(
+                    tensor_format,
+                    pad,
+                    stride,
+                    dilation,
+                    list(x.shape),
+                    list(w.shape),
+                    oshape,
+                    x.dtype,
+                    conv_dtype,
+                    groups,
+                )
+    else:
+        # Dynamic batch size case: pretend this is a single batch.
+        x_shape[0] = 1
+        oshape = conv_output_shape(
+            tensor_format,
+            pad,
+            stride,
+            dilation,
+            x_shape,
+            list(w.shape),
+            x.dtype,
+            conv_dtype,
+            groups,
+        )
+        oshape[0] = x.shape[0]
+        # This picks CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM.
+        # It seems to be the fastest among the algorithms that are always applicable.
+        algo = 1
 
     if dims == 4:
         return te.extern(
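The new branch in `conv_forward` hinges on whether the batch dimension is a compile-time constant. A minimal sketch of that check, assuming an ordinary TVM build (the placeholder names here are illustrative, not from this PR):

```python
# Sketch only: how conv_forward's IntImm check distinguishes static from
# dynamic batch sizes. Names are illustrative.
import tvm
from tvm import te

static_x = te.placeholder((8, 3, 224, 224), name="static_x")
dynamic_x = te.placeholder((te.var("n"), 3, 224, 224), name="dynamic_x")

# A literal batch size comes through as tir.IntImm, so the real output shape
# can be computed and cuDNN algorithm selection can run on it.
assert isinstance(static_x.shape[0], tvm.tir.expr.IntImm)

# A symbolic batch size comes through as tir.Var, so conv_forward computes
# the shape for batch 1, patches oshape[0] back, and fixes algo = 1.
assert not isinstance(dynamic_x.shape[0], tvm.tir.expr.IntImm)
```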
24 changes: 13 additions & 11 deletions python/tvm/topi/cuda/conv2d.py
@@ -96,17 +96,19 @@ def conv2d_cudnn(
     pt, pl, pb, pr = get_pad_tuple(padding, (KH, KW))
     OH = (H + pt + pb - KH) // stride_h + 1
     OW = (W + pl + pr - KW) // stride_w + 1
-    cfg.add_flop(
-        groups
-        * 2
-        * N
-        * OH
-        * OW
-        * CO
-        * CI
-        * ((KH - 1) * dilation_h + 1)
-        * ((KW - 1) * dilation_w + 1)
-    )
+
+    if isinstance(N, int):
+        cfg.add_flop(
+            groups
+            * 2
+            * N
+            * OH
+            * OW
+            * CO
+            * CI
+            * ((KH - 1) * dilation_h + 1)
+            * ((KW - 1) * dilation_w + 1)
+        )
 
     if data.dtype == "int8" or kernel.dtype == "int8":
         if layout == "NCHW":
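The `isinstance(N, int)` guard exists because `cfg.add_flop` records a concrete FLOP count for AutoTVM, and under a dynamic batch `N` arrives as a symbolic TIR expression rather than a Python integer. A small sketch of the distinction (variable names are illustrative):

```python
# Sketch: the two kinds of batch dimension conv2d_cudnn can see.
import tvm

N_static = 8                           # static shape: a plain Python int
N_dynamic = tvm.tir.Var("n", "int32")  # relay.Any() batch: a symbolic Var

assert isinstance(N_static, int)       # FLOPs well-defined, cfg.add_flop runs
assert not isinstance(N_dynamic, int)  # no concrete count, add_flop skipped
```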
26 changes: 14 additions & 12 deletions python/tvm/topi/cuda/conv3d.py
@@ -206,18 +206,20 @@ def conv3d_cudnn(
     OD = (D + 2 * pad_d - KD) // stride_d + 1
     OH = (H + 2 * pad_h - KH) // stride_h + 1
     OW = (W + 2 * pad_w - KW) // stride_w + 1
-    cfg.add_flop(
-        2
-        * N
-        * OD
-        * OH
-        * OW
-        * CO
-        * CI
-        * ((KD - 1) * dilation_d + 1)
-        * ((KH - 1) * dilation_h + 1)
-        * ((KW - 1) * dilation_w + 1)
-    )
+
+    if isinstance(N, int):
+        cfg.add_flop(
+            2
+            * N
+            * OD
+            * OH
+            * OW
+            * CO
+            * CI
+            * ((KD - 1) * dilation_d + 1)
+            * ((KH - 1) * dilation_h + 1)
+            * ((KW - 1) * dilation_w + 1)
+        )
 
     return cudnn.conv_forward(
         data,
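The guarded expression is the standard direct-convolution FLOP count: 2 * N * OD * OH * OW * CO * CI, times the three dilated kernel extents. A worked example under assumed static shapes:

```python
# Worked example of the conv3d FLOP formula above, with assumed static shapes.
N, CI, D, H, W = 1, 16, 8, 32, 32          # input tensor, NCDHW layout
CO, KD, KH, KW = 32, 3, 3, 3               # kernel
stride_d = stride_h = stride_w = 1
pad_d = pad_h = pad_w = 1
dilation_d = dilation_h = dilation_w = 1

OD = (D + 2 * pad_d - KD) // stride_d + 1  # 8
OH = (H + 2 * pad_h - KH) // stride_h + 1  # 32
OW = (W + 2 * pad_w - KW) // stride_w + 1  # 32

flops = (
    2 * N * OD * OH * OW * CO * CI
    * ((KD - 1) * dilation_d + 1)
    * ((KH - 1) * dilation_h + 1)
    * ((KW - 1) * dilation_w + 1)
)
print(flops)  # 2 * 8 * 32 * 32 * 32 * 16 * 27 = 226,492,416
```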
3 changes: 2 additions & 1 deletion python/tvm/topi/cuda/dense.py
@@ -42,7 +42,8 @@ def dense_cublas(cfg, data, weight, bias=None, out_dtype=None):
     batch, in_dim = data.shape
     out_dim, _ = weight.shape
     matmul = cublas.matmul(data, weight, False, True)
-    cfg.add_flop(batch * in_dim * out_dim * 2)
+    if isinstance(batch, int):
+        cfg.add_flop(batch * in_dim * out_dim * 2)
     if bias is not None:
         matmul = te.compute(
             (batch, out_dim), lambda i, j: matmul[i, j] + bias[j], tag=tag.BROADCAST
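Dense follows the same pattern: the 2 * batch * in_dim * out_dim FLOP count is recorded only for a static batch, while the cuBLAS offload itself is unaffected. A hedged sketch of the `cublas.matmul` contract this relies on, assuming a CUDA build of TVM with cuBLAS enabled and illustrative shapes:

```python
# Sketch: dense as a cuBLAS extern op, mirroring what dense_cublas builds.
# Assumes a CUDA-enabled TVM build with USE_CUBLAS on; shapes are illustrative.
import numpy as np
import tvm
from tvm import te
from tvm.contrib import cublas

batch, in_dim, out_dim = 4, 40, 50
A = te.placeholder((batch, in_dim), name="A", dtype="float32")
B = te.placeholder((out_dim, in_dim), name="B", dtype="float32")
C = cublas.matmul(A, B, transa=False, transb=True)  # C = A @ B.T

s = te.create_schedule(C.op)
if tvm.get_global_func("tvm.contrib.cublas.matmul", True):
    f = tvm.build(s, [A, B, C], "cuda")
    ctx = tvm.gpu(0)
    a = tvm.nd.array(np.random.uniform(size=(batch, in_dim)).astype("float32"), ctx)
    b = tvm.nd.array(np.random.uniform(size=(out_dim, in_dim)).astype("float32"), ctx)
    c = tvm.nd.array(np.zeros((batch, out_dim), dtype="float32"), ctx)
    f(a, b, c)  # the matmul runs inside cuBLAS
```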
50 changes: 41 additions & 9 deletions tests/python/relay/test_any.py
@@ -72,12 +72,11 @@ def check_result(
                         str(e),
                         str(r),
                     )
-                    return
-
-                if flatten:
-                    r = r.flatten()
-                    e = e.flatten()
-                tvm.testing.assert_allclose(r, e, atol=2e-6)
+                else:
+                    if flatten:
+                        r = r.flatten()
+                        e = e.flatten()
+                    tvm.testing.assert_allclose(r, e, atol=2e-6)
 
 
 def verify_any_broadcast(x_shape, y_shape, x_np_shape, y_np_shape, op, np_op):
@@ -454,6 +453,7 @@ def verify_any_conv2d(
     dilation,
     static_data_shape,
     ref_out_shape,
+    use_cudnn=False,
 ):
     mod = tvm.IRModule()
     dtype = "float32"
@@ -463,7 +463,12 @@ def verify_any_conv2d(
     mod["main"] = relay.Function([data, kernel], y)
     data_np = np.random.uniform(size=static_data_shape).astype(dtype)
     kernel_np = np.random.uniform(size=kernel_shape).astype(dtype)
-    check_result([data_np, kernel_np], mod, ref_out_shape, assert_shape=True)
+
+    targets = None
+    if use_cudnn and tvm.get_global_func("tvm.contrib.cudnn.conv.output_shape", True):
+        targets = [("cuda -libs=cudnn", tvm.gpu(0))]
+
+    check_result([data_np, kernel_np], mod, ref_out_shape, assert_shape=True, targets=targets)
 
 
 # TODO(@kevinthesun): Support dynamic input height and width.
@@ -487,6 +492,16 @@ def test_any_conv2d():
         (2, 64, 224, 224),
         (2, 64, 222, 222),
     )
+    verify_any_conv2d(
+        (relay.Any(), 64, 224, 224),
+        (64, 64, 3, 3),
+        (1, 1),
+        (1, 1),
+        (1, 1),
+        (1, 64, 224, 224),
+        (1, 64, 224, 224),
+        use_cudnn=True,
+    )
 
 
 def verify_any_conv2d_NCHWc(
@@ -724,7 +739,13 @@ def test_any_batch_flatten():
 
 
 def verify_any_dense(
-    data_shape, weight_shape, units, static_data_shape, static_weight_shape, ref_out_shape
+    data_shape,
+    weight_shape,
+    units,
+    static_data_shape,
+    static_weight_shape,
+    ref_out_shape,
+    use_cublas=False,
 ):
     mod = tvm.IRModule()
     dtype = "float32"
@@ -734,7 +755,12 @@ def verify_any_dense(
     mod["main"] = relay.Function([data, weight], y)
     data_np = np.random.uniform(size=static_data_shape).astype(dtype)
     weight_np = np.random.uniform(size=static_weight_shape).astype(dtype)
-    check_result([data_np, weight_np], mod, ref_out_shape, assert_shape=True)
+
+    targets = None
+    if use_cublas and tvm.get_global_func("tvm.contrib.cublas.matmul", True):
+        targets = [("cuda -libs=cublas", tvm.gpu(0))]
+
+    check_result([data_np, weight_np], mod, ref_out_shape, assert_shape=True, targets=targets)
 
 
 # TODO(tvm-team) Fix dense schedule
@@ -744,6 +770,12 @@ def test_any_dense():
     verify_any_dense(any_dims(2), (50, relay.Any()), 50, (4, 40), (50, 40), (4, 50))
 
 
+@tvm.testing.uses_gpu
+def test_any_dense_dynamic_batch():
+    verify_any_dense((relay.Any(), 40), (50, 40), 50, (4, 40), (50, 40), (4, 50))
+    verify_any_dense((relay.Any(), 40), (50, 40), 50, (4, 40), (50, 40), (4, 50), use_cublas=True)
+
+
 @tvm.testing.uses_gpu
 def verify_any_pad(data_shape, pad_width, static_data_shape):
     mod = tvm.IRModule()
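The new tests exercise the whole path end to end. For reference, a minimal standalone sketch of the dynamic-batch cuDNN case the test above covers, assuming a GPU build of TVM (circa v0.8, when this PR merged) with cuDNN enabled:

```python
# Minimal standalone sketch of the dynamic-batch cuDNN path; assumes a GPU
# build of TVM with cuDNN enabled and the era-appropriate tvm.gpu/ctx API.
import numpy as np
import tvm
from tvm import relay

data = relay.var("data", shape=(relay.Any(), 64, 224, 224), dtype="float32")
kernel = relay.var("kernel", shape=(64, 64, 3, 3), dtype="float32")
y = relay.nn.conv2d(data, kernel, padding=(1, 1), kernel_size=(3, 3))
mod = tvm.IRModule.from_expr(relay.Function([data, kernel], y))

# The VM executor resolves the symbolic batch dimension at runtime.
ex = relay.create_executor("vm", mod=mod, ctx=tvm.gpu(0), target="cuda -libs=cudnn")
data_np = np.random.uniform(size=(1, 64, 224, 224)).astype("float32")
kernel_np = np.random.uniform(size=(64, 64, 3, 3)).astype("float32")
out = ex.evaluate()(data_np, kernel_np)
assert out.asnumpy().shape == (1, 64, 224, 224)
```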