Skip to content

Commit

Permalink
Asymmetric padding and dilation in conv2d workload (apache#7142)
Browse files Browse the repository at this point in the history
* added asymmetric padding to conv2d workload

* fixed depthwise conv2d padding

* Added fix to include dilation in workload output width calculation

* Added missing dilation to arm_cpu/conv2d_int8.py workload

* Fixed dilation for x86 conv2d

* Improved dilation workload integration in x86

* Fixed x86 conv2d_alter_op to add dilation

* Local linting not always producing same output as CI, probably my fault

* Fixed bug, tested locally

* Abusing CI until I can figure out how to reproduce the same behaviour of running integration tests locally.

* Ammeded conv2d_int8 test

* Updated workload, improved unit tests

* Added depthwise conv2d workload test
  • Loading branch information
Wheest authored and trevor-m committed Jan 21, 2021
1 parent 046a739 commit 6c053ae
Show file tree
Hide file tree
Showing 15 changed files with 201 additions and 63 deletions.
7 changes: 4 additions & 3 deletions python/tvm/topi/arm_cpu/conv2d_int8.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,12 @@
from .arm_utils import get_tiling_B_interleaved_t


def _get_default_config(cfg, data, kernel, strides, padding, out_dtype):
def _get_default_config(cfg, data, kernel, strides, padding, dilation, out_dtype):
"""
Get default int8 schedule config for the workload
"""
wkl = _get_conv2d_workload(data, kernel, strides, padding, out_dtype)
is_kernel_1x1 = wkl.hkernel == 1 and wkl.wkernel == 1
wkl = _get_conv2d_workload(data, kernel, strides, padding, dilation, out_dtype)
is_kernel_1x1 = wkl.kernel_h == 1 and wkl.kernel_w == 1
if is_kernel_1x1:
conv2d_generic.fallback_schedule_cpu_1x1_int8(cfg, wkl, int32_lanes=2, num_int8_elements=4)
else:
Expand Down Expand Up @@ -65,6 +65,7 @@ def conv2d_NCHWc_int8(cfg, data, kernel, strides, padding, dilation, layout, out
te.placeholder((num_filter, in_channel, kh, kw), dtype=kernel.dtype),
strides,
padding,
dilation,
out_dtype,
)
return nn.conv2d_NCHWc_int8_compute(
Expand Down
7 changes: 4 additions & 3 deletions python/tvm/topi/cuda/conv2d_int8.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,9 +142,10 @@ def conv2d_NCHWc_int8(cfg, data, kernel, stride, padding, dilation, layout, out_
pad_data = pad(packed_data, pad_before, pad_after, name="pad_data")

# compute the output shape
out_height = (in_height - (kernel_h - 1) * dilation_h - 1 + pad_top + pad_down) // stride_h + 1
out_width = (in_width - (kernel_w - 1) * dilation_w - 1 + pad_left + pad_right) // stride_w + 1

dilated_kernel_h = (kernel_h - 1) * dilation_h + 1
dilated_kernel_w = (kernel_w - 1) * dilation_w + 1
out_height = (in_height - dilated_kernel_h + pad_top + pad_down) // stride_h + 1
out_width = (in_width - dilated_kernel_w + pad_left + pad_right) // stride_w + 1
oshape = (batch, oc_chunk, out_height, out_width, oc_block)

icc = te.reduce_axis((0, ic_chunk), name="ic_chunk")
Expand Down
15 changes: 8 additions & 7 deletions python/tvm/topi/generic/conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,10 @@ def fallback_schedule_cpu_common_int8(cfg, wkl, int32_lanes, num_int8_elements):
How many numbers of input int32/uint32 will be multiplied and reduced.
This is related to input channel.
"""
HPAD, WPAD = wkl.hpad, wkl.wpad
HSTR, WSTR = wkl.hstride, wkl.wstride
out_width = (wkl.width + 2 * WPAD - wkl.wkernel) // WSTR + 1
pt, pl, pb, pr = wkl.padt, wkl.padl, wkl.padb, wkl.padr
HSTR, WSTR = wkl.stride_h, wkl.stride_w
dilated_kernel_w = (wkl.kernel_w - 1) * wkl.dilation_w + 1
out_width = (wkl.width + pl + pr - dilated_kernel_w) // WSTR + 1

assert wkl.out_filter % int32_lanes == 0, "wkl.out_filter=%d, int32_lanes=%d" % (
wkl.out_filter,
Expand Down Expand Up @@ -85,10 +86,10 @@ def fallback_schedule_cpu_1x1_int8(cfg, wkl, int32_lanes, num_int8_elements):
How many numbers of input int32/uint32 will be multiplied and reduced.
This is related to input channel.
"""
HPAD, WPAD = wkl.hpad, wkl.wpad
HSTR, WSTR = wkl.hstride, wkl.wstride
out_height = (wkl.height + 2 * HPAD - wkl.hkernel) // HSTR + 1
out_width = (wkl.width + 2 * WPAD - wkl.wkernel) // WSTR + 1
pt, pl, pb, pr = wkl.padt, wkl.padl, wkl.padb, wkl.padr
HSTR, WSTR = wkl.stride_h, wkl.stride_w
out_height = (wkl.height + pt + pb - wkl.kernel_h) // HSTR + 1
out_width = (wkl.width + pl + pr - wkl.kernel_w) // WSTR + 1

assert wkl.out_filter % int32_lanes == 0, "wkl.out_filter=%d, int32_lanes=%d" % (
wkl.out_filter,
Expand Down
43 changes: 34 additions & 9 deletions python/tvm/topi/nn/conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,16 @@
"in_filter",
"groups",
"out_filter",
"hkernel",
"wkernel",
"hpad",
"wpad",
"hstride",
"wstride",
"kernel_h",
"kernel_w",
"padt",
"padl",
"padb",
"padr",
"dilation_h",
"dilation_w",
"stride_h",
"stride_w",
],
)

Expand Down Expand Up @@ -154,7 +158,7 @@ def conv2d_infer_layout(workload, cfg):
raise ValueError("missing register for topi.nn.conv2d_infer_layout")


def _get_workload(data, kernel, stride, padding, out_dtype, data_layout="NCHW"):
def _get_workload(data, kernel, stride, padding, dilation, out_dtype, data_layout="NCHW"):
""" Get the workload structure. """
if data_layout == "NCHW":
_, CI, IH, IW = get_const_tuple(data.shape)
Expand All @@ -170,7 +174,10 @@ def _get_workload(data, kernel, stride, padding, out_dtype, data_layout="NCHW"):
else:
KH, KW, CIG, CO = get_const_tuple(kernel.shape)

HPAD, WPAD, _, _ = get_pad_tuple(padding, (get_const_int(KH), get_const_int(KW)))
pt, pl, pb, pr = get_pad_tuple(padding, (get_const_int(KH), get_const_int(KW)))
dilation_h, dilation_w = (
dilation if isinstance(dilation, (tuple, list)) else (dilation, dilation)
)
GRPS = CI // CIG
if isinstance(stride, (tuple, list)):
HSTR, WSTR = stride
Expand All @@ -182,7 +189,25 @@ def _get_workload(data, kernel, stride, padding, out_dtype, data_layout="NCHW"):
'{} vs. {}".format(
data.dtype, kernel.dtype
)
return Workload(data.dtype, out_dtype, IH, IW, CI, GRPS, CO, KH, KW, HPAD, WPAD, HSTR, WSTR)
return Workload(
data.dtype,
out_dtype,
IH,
IW,
CI,
GRPS,
CO,
KH,
KW,
pt,
pl,
pb,
pr,
dilation_h,
dilation_w,
HSTR,
WSTR,
)


def conv2d_nchw(Input, Filter, stride, padding, dilation, out_dtype=None):
Expand Down
33 changes: 23 additions & 10 deletions python/tvm/topi/nn/depthwise_conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,22 +36,28 @@
"width",
"in_filter",
"out_filter",
"hkernel",
"wkernel",
"hpad",
"wpad",
"hstride",
"wstride",
"kernel_h",
"kernel_w",
"padt",
"padl",
"padb",
"padr",
"dilation_h",
"dilation_w",
"stride_h",
"stride_w",
],
)


def _get_workload(data, kernel, stride, padding, out_dtype):
def _get_workload(data, kernel, stride, padding, dilation, out_dtype):
""" Get the workload structure. """
_, in_channel, height, width = [x.value for x in data.shape]
channel, channel_multiplier, kh, kw = [x.value for x in kernel.shape]
out_channel = channel * channel_multiplier
HPAD, WPAD, _, _ = get_pad_tuple(padding, kernel)
dilation_h, dilation_w = (
dilation if isinstance(dilation, (tuple, list)) else (dilation, dilation)
)
if isinstance(stride, (tuple, list)):
HSTR, WSTR = stride
else:
Expand All @@ -62,6 +68,9 @@ def _get_workload(data, kernel, stride, padding, out_dtype):
'{} vs. {}".format(
data.dtype, kernel.dtype
)
dilated_kernel_h = (kh - 1) * dilation_h + 1
dilated_kernel_w = (kw - 1) * dilation_w + 1
pt, pl, pb, pr = get_pad_tuple(padding, (dilated_kernel_h, dilated_kernel_w))
return Workload(
data.dtype,
out_dtype,
Expand All @@ -71,8 +80,12 @@ def _get_workload(data, kernel, stride, padding, out_dtype):
out_channel,
kh,
kw,
HPAD,
WPAD,
pt,
pl,
pb,
pr,
dilation_h,
dilation_w,
HSTR,
WSTR,
)
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/topi/testing/depthwise_conv2d_python.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def depthwise_conv2d_python_nchw(input_np, filter_np, stride, padding):
0 : (in_height - filter_height + 1) : stride_h,
0 : (in_width - filter_width + 1) : stride_w,
]
if padding == "SAME":
elif padding == "SAME":
out_channel = in_channel * channel_multiplier
out_height = np.int(np.ceil(float(in_height) / float(stride_h)))
out_width = np.int(np.ceil(float(in_width) / float(stride_w)))
Expand Down
16 changes: 10 additions & 6 deletions python/tvm/topi/x86/conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@


def _get_default_config(
cfg, data, kernel, strides, padding, out_dtype, is_depthwise=False, layout="NCHW"
cfg, data, kernel, strides, padding, dilation, out_dtype, is_depthwise=False, layout="NCHW"
):
"""
Get default schedule config for the workload
Expand All @@ -48,13 +48,13 @@ def _get_default_config(
static_data_shape.append(dim)
data = te.placeholder(static_data_shape, dtype=data.dtype)
if is_depthwise:
wkl = _get_depthwise_conv2d_workload(data, kernel, strides, padding, out_dtype)
wkl = _get_depthwise_conv2d_workload(data, kernel, strides, padding, dilation, out_dtype)
from .depthwise_conv2d import _fallback_schedule

_fallback_schedule(cfg, wkl)
else:
wkl = _get_conv2d_workload(data, kernel, strides, padding, out_dtype, layout)
is_kernel_1x1 = wkl.hkernel == 1 and wkl.wkernel == 1
wkl = _get_conv2d_workload(data, kernel, strides, padding, dilation, out_dtype, layout)
is_kernel_1x1 = wkl.kernel_h == 1 and wkl.kernel_w == 1
if is_kernel_1x1:
conv2d_avx_1x1._fallback_schedule(cfg, wkl)
else:
Expand All @@ -69,8 +69,11 @@ def _conv2d_infer_layout(workload, cfg):
idxdiv = tvm.tir.indexdiv

pt, pl, pb, pr = get_pad_tuple(padding, (k_height, k_width))
out_height = idxdiv(in_height + pt + pb - k_height, strides[0]) + 1
out_width = idxdiv(in_width + pl + pr - k_width, strides[1]) + 1
hdilation, wdilation = dilation if isinstance(dilation, (tuple, list)) else (dilation, dilation)
dilated_kernel_h = (k_height - 1) * hdilation + 1
dilated_kernel_w = (k_width - 1) * wdilation + 1
out_height = idxdiv(in_height + pt + pb - dilated_kernel_h, strides[0]) + 1
out_width = idxdiv(in_width + pl + pr - dilated_kernel_w, strides[1]) + 1
tile_ic, tile_oc = cfg["tile_ic"].size[-1], cfg["tile_oc"].size[-1]
in_shape = (batch_size, idxdiv(in_channel, tile_ic), in_height, in_width, tile_ic)
in_layout = "NCHW%dc" % tile_ic
Expand Down Expand Up @@ -208,6 +211,7 @@ def conv2d_NCHWc(cfg, data, kernel, strides, padding, dilation, layout, out_layo
),
strides,
padding,
dilation,
out_dtype,
)

Expand Down
30 changes: 27 additions & 3 deletions python/tvm/topi/x86/conv2d_alter_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,15 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):
if data_layout == "NCHW" and kernel_layout == "OIHW":
if cfg.is_fallback:
_get_default_config(
cfg, data_tensor, kernel_tensor, strides, padding, out_dtype, False, data_layout
cfg,
data_tensor,
kernel_tensor,
strides,
padding,
dilation,
out_dtype,
False,
data_layout,
)
batch_size, in_channel, height, width = get_const_tuple(data_tensor.shape)
out_channel, _, kh, kw = get_const_tuple(kernel_tensor.shape)
Expand Down Expand Up @@ -142,7 +150,15 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):
assert data_layout == "NCHW" and kernel_layout == "OIHW"
if cfg.is_fallback:
_get_default_config_int8(
cfg, data_tensor, kernel_tensor, strides, padding, out_dtype, False, data_layout
cfg,
data_tensor,
kernel_tensor,
strides,
padding,
dilation,
out_dtype,
False,
data_layout,
)

batch_size, in_channel, height, width = get_const_tuple(data_tensor.shape)
Expand Down Expand Up @@ -198,7 +214,15 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):
if data_layout == "NCHW" and kernel_layout == "OIHW":
if cfg.is_fallback:
_get_default_config(
cfg, data_tensor, kernel_tensor, strides, padding, out_dtype, True, data_layout
cfg,
data_tensor,
kernel_tensor,
strides,
padding,
dilation,
out_dtype,
True,
data_layout,
)

batch_size, in_channel, height, width = get_const_tuple(data_tensor.shape)
Expand Down
11 changes: 7 additions & 4 deletions python/tvm/topi/x86/conv2d_avx_1x1.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,13 @@

def _fallback_schedule(cfg, wkl):
simd_width = get_fp32_len()
HPAD, WPAD = wkl.hpad, wkl.wpad
HSTR, WSTR = wkl.hstride, wkl.wstride
out_height = (wkl.height + 2 * HPAD - wkl.hkernel) // HSTR + 1
out_width = (wkl.width + 2 * WPAD - wkl.wkernel) // WSTR + 1
pt, pl, pb, pr = wkl.padt, wkl.padl, wkl.padb, wkl.padr
HSTR, WSTR = wkl.stride_h, wkl.stride_w
dilated_kernel_h = (wkl.kernel_h - 1) * wkl.dilation_h + 1
dilated_kernel_w = (wkl.kernel_w - 1) * wkl.dilation_w + 1

out_height = (wkl.height + pt + pb - dilated_kernel_h) // HSTR + 1
out_width = (wkl.width + pl + pr - dilated_kernel_w) // WSTR + 1

oc_bn = 1
for bn in range(simd_width, 0, -1):
Expand Down
14 changes: 8 additions & 6 deletions python/tvm/topi/x86/conv2d_avx_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,11 @@

def _fallback_schedule(cfg, wkl):
simd_width = get_fp32_len()
HPAD, WPAD = wkl.hpad, wkl.wpad
HSTR, WSTR = wkl.hstride, wkl.wstride
out_width = (wkl.width + 2 * WPAD - wkl.wkernel) // WSTR + 1
pt, pl, pb, pr = wkl.padt, wkl.padl, wkl.padb, wkl.padr
HSTR, WSTR = wkl.stride_h, wkl.stride_w
dilated_kernel_w = (wkl.kernel_w - 1) * wkl.dilation_w + 1

out_width = (wkl.width + pl + pr - dilated_kernel_w) // WSTR + 1

oc_bn = 1
for bn in range(simd_width, 0, -1):
Expand All @@ -56,9 +58,9 @@ def _fallback_schedule(cfg, wkl):


def _fallback_schedule_int8(cfg, wkl):
HPAD, WPAD = wkl.hpad, wkl.wpad
HSTR, WSTR = wkl.hstride, wkl.wstride
out_width = (wkl.width + 2 * WPAD - wkl.wkernel) // WSTR + 1
pt, pl, pb, pr = wkl.padt, wkl.padl, wkl.padb, wkl.padr
HSTR, WSTR = wkl.stride_h, wkl.stride_w
out_width = (wkl.width + pl + pr - wkl.kernel_w) // WSTR + 1

oc_bn = 16
assert wkl.out_filter % oc_bn == 0
Expand Down
14 changes: 9 additions & 5 deletions python/tvm/topi/x86/conv2d_int8.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@


def _get_default_config_int8(
cfg, data, kernel, strides, padding, out_dtype, is_depthwise=False, layout="NCHW"
cfg, data, kernel, strides, padding, dilation, out_dtype, is_depthwise=False, layout="NCHW"
):
"""
Get default schedule config for the workload
Expand All @@ -45,8 +45,8 @@ def _get_default_config_int8(

_fallback_schedule(cfg, wkl)
else:
wkl = _get_conv2d_workload(data, kernel, strides, padding, out_dtype, layout)
is_kernel_1x1 = wkl.hkernel == 1 and wkl.wkernel == 1
wkl = _get_conv2d_workload(data, kernel, strides, padding, dilation, out_dtype, layout)
is_kernel_1x1 = wkl.kernel_h == 1 and wkl.kernel_w == 1
if is_kernel_1x1:
conv2d_generic.fallback_schedule_cpu_1x1_int8(
cfg, wkl, int32_lanes=16, num_int8_elements=4
Expand Down Expand Up @@ -138,8 +138,11 @@ def conv2d_NCHWc_int8(cfg, data, kernel, strides, padding, dilation, layout, out
is_kernel_1x1 = kernel_height == 1 and kernel_width == 1
pt, pl, pb, pr = get_pad_tuple(padding, (kernel_height, kernel_width))
sh, sw = strides if isinstance(strides, (tuple, list)) else (strides, strides)
oh = (ih - kernel_height + pt + pb) // sh + 1
ow = (iw - kernel_width + pl + pr) // sw + 1
dh, dw = dilation if isinstance(dilation, (tuple, list)) else (dilation, dilation)
dilated_kernel_h = (kernel_height - 1) * dh + 1
dilated_kernel_w = (kernel_width - 1) * dw + 1
oh = (ih - dilated_kernel_h + pt + pb) // sh + 1
ow = (iw - dilated_kernel_w + pl + pr) // sw + 1

cfg.define_split("tile_ic", in_channel, num_outputs=2, filter=lambda y: y.size[-1] % 4 == 0)
cfg.define_split("tile_oc", num_filter, num_outputs=2, filter=lambda y: y.size[-1] % 16 == 0)
Expand All @@ -159,6 +162,7 @@ def conv2d_NCHWc_int8(cfg, data, kernel, strides, padding, dilation, layout, out
),
strides,
padding,
dilation,
out_dtype,
)

Expand Down
Loading

0 comments on commit 6c053ae

Please sign in to comment.