From fec4a99225a7cecf27396bd3587d6974f0f64fcc Mon Sep 17 00:00:00 2001 From: lingyiliu Date: Fri, 5 Apr 2019 13:54:17 -0700 Subject: [PATCH 01/17] Support the 1x1 int8 conv with NHWC layout and weight packing --- topi/python/topi/nn/conv2d.py | 24 +++- topi/python/topi/x86/conv2d.py | 50 +++++++-- topi/python/topi/x86/conv2d_avx_1x1.py | 106 +++++++++++++++++- .../python/test_topi_conv2d_nhwc_pack_int8.py | 71 ++++++++++++ 4 files changed, 237 insertions(+), 14 deletions(-) create mode 100644 topi/tests/python/test_topi_conv2d_nhwc_pack_int8.py diff --git a/topi/python/topi/nn/conv2d.py b/topi/python/topi/nn/conv2d.py index a67f608d26dc5..1fd7bd33378e8 100644 --- a/topi/python/topi/nn/conv2d.py +++ b/topi/python/topi/nn/conv2d.py @@ -12,7 +12,7 @@ # workload description of conv2d Workload = namedtuple('Workload', - ['in_dtype', 'out_dtype', 'height', 'width', 'in_filter', 'out_filter', + ['in_dtype', 'out_dtype', 'height', 'width', 'in_filter', 'groups', 'out_filter', 'hkernel', 'wkernel', 'hpad', 'wpad', 'hstride', 'wstride']) @tvm.target.generic_func @@ -79,11 +79,25 @@ def conv2d_alter_layout(attrs, inputs, tinfos, F): return None -def _get_workload(data, kernel, stride, padding, out_dtype): +def _get_workload(data, kernel, stride, padding, out_dtype, data_layout='NCHW'): """ Get the workload structure. """ - _, CI, IH, IW = [x.value for x in data.shape] - CO, _, KH, KW = [x.value for x in kernel.shape] + if data_layout == 'NCHW': + _, CI, IH, IW = [x.value for x in data.shape] + elif data_layout == 'NHWC': + _, IH, IW, CI = [x.value for x in data.shape] + elif data_layout == 'HWCN': + IH, IW, CI, _ = [x.value for x in data.shape] + else: + raise ValueError("not support this layout {} yet".format(data_layout)) + + + if data_layout == 'NHWC': + KH, KW, CO, CIG = [x.value for x in kernel.shape] + else: + CO, CIG, KH, KW = [x.value for x in kernel.shape] + HPAD, WPAD, _, _ = get_pad_tuple(padding, kernel) + GRPS = CI // CIG if isinstance(stride, (tuple, list)): HSTR, WSTR = stride else: @@ -91,7 +105,7 @@ def _get_workload(data, kernel, stride, padding, out_dtype): assert (data.dtype == kernel.dtype) or (data.dtype == 'uint8' and kernel.dtype == 'int8'), \ "Do not support inputs with different data types now. ' \ '{} vs. 
{}".format(data.dtype, kernel.dtype) - return Workload(data.dtype, out_dtype, IH, IW, CI, CO, KH, KW, HPAD, WPAD, HSTR, WSTR) + return Workload(data.dtype, out_dtype, IH, IW, CI, GRPS, CO, KH, KW, HPAD, WPAD, HSTR, WSTR) def conv2d_nchw(Input, Filter, stride, padding, dilation, out_dtype=None): diff --git a/topi/python/topi/x86/conv2d.py b/topi/python/topi/x86/conv2d.py index c6367d07876be..9934843286222 100644 --- a/topi/python/topi/x86/conv2d.py +++ b/topi/python/topi/x86/conv2d.py @@ -20,7 +20,7 @@ logger = logging.getLogger('topi') -def _get_default_config(cfg, data, kernel, strides, padding, out_dtype, is_depthwise=False): +def _get_default_config(cfg, data, kernel, strides, padding, out_dtype, is_depthwise=False, layout='NCHW'): """ Get default schedule config for the workload """ @@ -29,7 +29,7 @@ def _get_default_config(cfg, data, kernel, strides, padding, out_dtype, is_depth from .depthwise_conv2d import _fallback_schedule _fallback_schedule(cfg, wkl) else: - wkl = _get_conv2d_workload(data, kernel, strides, padding, out_dtype) + wkl = _get_conv2d_workload(data, kernel, strides, padding, out_dtype, layout) is_kernel_1x1 = wkl.hkernel == 1 and wkl.wkernel == 1 if is_kernel_1x1: conv2d_avx_1x1._fallback_schedule(cfg, wkl) @@ -44,6 +44,9 @@ def _create_tuning_space(cfg, data, kernel, strides, padding, dilation, layout): if layout == 'NCHW': n, ic, h, w = dshape oc, _, kh, kw = kshape + elif layout == 'NHWC': + n, h, w, ic = dshape + oc, _, kh, kw = kshape else: raise ValueError("Not support this layout {} with " "schedule template.".format(layout)) @@ -63,12 +66,14 @@ def _create_tuning_space(cfg, data, kernel, strides, padding, dilation, layout): cfg.define_knob("unroll_kw", [True, False]) -@autotvm.register_topi_compute(conv2d, 'cpu', 'direct') +@autotvm.register_topi_compute(conv2d, 'cpu', ['direct']) def _declaration_conv(cfg, data, kernel, strides, padding, dilation, layout, out_dtype): out_dtype = data.dtype if out_dtype is None else out_dtype padding = padding if isinstance(padding, (tuple, list)) else (padding, padding) strides = strides if isinstance(strides, (tuple, list)) else (strides, strides) dilation = dilation if isinstance(dilation, (tuple, list)) else (dilation, dilation) + + _, _, kh, kw = get_const_tuple(kernel.shape) if layout == 'NCHW': _create_tuning_space(cfg, data, kernel, strides, padding, dilation, layout) if cfg.is_fallback: @@ -77,7 +82,13 @@ def _declaration_conv(cfg, data, kernel, strides, padding, dilation, layout, out padding, dilation, layout, out_dtype) if layout == 'HWCN': return nn.conv2d_hwcn(data, kernel, strides, padding, dilation, out_dtype) - if layout == 'NHWC': + elif layout == 'NHWC' and kh == 1 and kw == 1 and kernel.dtype == "int8": + if cfg.is_fallback: + _get_default_config(cfg, data, kernel, strides, padding, out_dtype, False, layout) + # specialize for INT8 1X1 conv on X86 + return conv2d_avx_1x1._declaration_conv_nhwc_pack(cfg, data, kernel, strides, + padding, dilation, out_dtype) + elif layout == 'NHWC': return nn.conv2d_nhwc(data, kernel, strides, padding, dilation, out_dtype) raise ValueError("not support this layout {} yet".format(layout)) @@ -196,8 +207,9 @@ def traverse(op): return s -@generic.schedule_conv2d_nhwc.register("cpu") -def schedule_conv2d_nhwc(outs): +# @generic.schedule_conv2d_nhwc.register("cpu") +@autotvm.register_topi_schedule(generic.schedule_conv2d_nhwc, 'cpu', ['direct']) +def schedule_conv2d_nhwc(cfg, outs): """Create schedule for tensors""" s = tvm.create_schedule([x.op for x in outs]) output_op = 
outs[0].op @@ -219,7 +231,31 @@ def traverse(op): if tensor.op.input_tensors and tensor.op not in scheduled_ops: traverse(tensor.op) - if 'conv2d_nhwc' in op.tag: + if 'conv2d_nhwc_pack_int8' in op.tag: + conv_out = op.output(0) + kernel = conv_out.op.input_tensors[1] + data_vec = conv_out.op.input_tensors[0] + data = data_vec.op.input_tensors[0] \ + if isinstance(data_vec.op, tvm.tensor.ComputeOp) and "pad" not in data_vec.op.tag \ + else data_vec + if isinstance(data.op, tvm.tensor.ComputeOp) and "pad" in data.op.tag: + data_pad = data + data = data_pad.op.input_tensors[0] + + args = [s, cfg, data_vec, conv_out, outs[0]] + if data.dtype == 'uint8': + # int8 conv kernel is 7-dim + kh, kw, _, _, _ = get_const_tuple(kernel.shape) + if kh == 1 and kw == 1: + conv2d_avx_1x1._schedule_conv_nhwc_pack_int8(*args) + else: + raise ValueError("Only support 1x1 kernel with " + "schedule template.") + else: + raise ValueError("Not support this data type {} with " + "schedule template.".format(data.dtype)) + + elif 'conv2d_nhwc' in op.tag: conv = op.output(0) kernel = op.input_tensors[1] if isinstance(kernel.op, tvm.tensor.ComputeOp) and "dilate" in kernel.op.tag: diff --git a/topi/python/topi/x86/conv2d_avx_1x1.py b/topi/python/topi/x86/conv2d_avx_1x1.py index d44e3899293da..81bfd81602662 100644 --- a/topi/python/topi/x86/conv2d_avx_1x1.py +++ b/topi/python/topi/x86/conv2d_avx_1x1.py @@ -4,8 +4,9 @@ import tvm from tvm.autotvm.task.space import SplitEntity, OtherOptionEntity -from ..nn.util import infer_pad -from ..util import get_const_tuple +from ..nn.pad import pad +from ..nn.util import infer_pad, get_pad_tuple +from ..util import get_const_tuple, simplify from .tensor_intrin import dot_16x1x16_int8_int8_int32 from .check_targets import check_skylake from .util import get_fp32_len @@ -235,3 +236,104 @@ def _schedule_conv_NCHWc_int8(s, cfg, data, conv_out, last): s[O].parallel(parallel_axis) return s + + +def _declaration_conv_nhwc_pack(cfg, Input, Filter, stride, padding, dilation, out_dtype): + # more assertion for the shapes + assert isinstance(stride, int) or len(stride) == 2 + assert isinstance(dilation, int) or len(dilation) == 2 + if isinstance(stride, int): + stride_h = stride_w = stride + else: + stride_h, stride_w = stride + + if isinstance(dilation, int): + dilation_h = dilation_w = dilation + else: + dilation_h, dilation_w = dilation + + batch, in_height, in_width, in_channel = Input.shape + kernel_h, kernel_w, num_filter, channel = Filter.shape + + # compute the output shape + dilated_kernel_h = (kernel_h - 1) * dilation_h + 1 + dilated_kernel_w = (kernel_w - 1) * dilation_w + 1 + pad_top, pad_left, pad_down, pad_right = get_pad_tuple( + padding, (dilated_kernel_h, dilated_kernel_w)) + out_channel = num_filter + out_height = simplify((in_height - dilated_kernel_h + pad_top + pad_down) // stride_h + 1) + out_width = simplify((in_width - dilated_kernel_w + pad_left + pad_right) // stride_w + 1) + pad_before = [0, pad_top, pad_left, 0] + pad_after = [0, pad_down, pad_right, 0] + PaddedInput = pad(Input, pad_before, pad_after, name="PaddedInput") + # todo: padding filter to accomodate the intrinsic + + # packing the Filter to let memory access be consecutive for AVX512 intrinsic + # Done in pre-compute stage + packw_shape = (kernel_h, kernel_w, num_filter/16, 16*(channel/4), 4) + PackW = tvm.compute(packw_shape, lambda a, b, c, d, e: Filter[a][b][c*16+d%16][d/16*4+e], name="packed_filter") + + rc = tvm.reduce_axis((0, in_channel), name='rc') + ry = tvm.reduce_axis((0, kernel_h), 
name='ry') + rx = tvm.reduce_axis((0, kernel_w), name='rx') + Output = tvm.compute( + (batch, out_height, out_width, out_channel), + lambda nn, yy, xx, ff: tvm.sum( + PaddedInput[nn, yy * stride_h + ry * dilation_h, + xx * stride_w + rx * dilation_w, rc].astype(out_dtype) * + PackW[ry, rx, ff/16, (rc/4)*16+ff%16, rc%4].astype(out_dtype), axis=[ry, rx, rc]), + name="Conv2d_1x1_Output_int8", tag="conv2d_nhwc_pack_int8") + return Output + + +def _schedule_conv_nhwc_pack_int8(s, cfg, data, conv_out, last): + """ + Defines the schedule for the int8 nhwc layout. For 1x1 conv, it + is a matrix-multiply operation by using nhwc layout. We will do + packing of weight to make the address access be friendly to int8 + intrinsic + """ + target = tvm.target.current_target(allow_none=False) + int32_lanes = -1 + if check_skylake(target): + int32_lanes = 16 + else: + return s + assert int32_lanes != -1 + + # assertion to fail the unhandled case + _, _, _, ic_num = get_const_tuple(data.shape) + _, _, _, oc_num = get_const_tuple(conv_out.shape) + assert ic_num % 4 == 0 + assert oc_num % 16 == 0 + + ic_factor, oc_factor = cfg["tile_ic"].size[-1], cfg["tile_oc"].size[-1] + # schedule data + A = data + if isinstance(s[A].op, tvm.tensor.ComputeOp): + batch, ih, iw, ic = s[A].op.axis + d_ic_chunk, d_ic_block = s[A].split(ic, factor=4) + s[A].vectorize(d_ic_block) + + C, O = conv_out, last + + batch, oh, ow, oc = s[C].op.axis + kh, kw, ic = s[C].op.reduce_axis + # match the x86 intrinsic + ic_outer, ic_inner = s[C].split(ic, factor=4) + oc_outer, oc_inner = s[C].split(oc, factor=int32_lanes) + + ic_f_outer, ic_s_outer = s[C].split(ic_outer, factor=ic_factor) + s[C].reorder(oc_outer, oh, ow, ic_f_outer, ic_s_outer, kh, kw, oc_inner, ic_inner) + + pc = dot_16x1x16_int8_int8_int32() + s[C].tensorize(oc_inner, pc) + + if C != O: + batch, last_oh, last_ow, last_oc = s[O].op.axis + oc_chunk, oc_block = s[O].split(ochannel, 16) + # not saw perf improvement to split oh/ow here + s[O].vectorize(oc_block) + + return s + diff --git a/topi/tests/python/test_topi_conv2d_nhwc_pack_int8.py b/topi/tests/python/test_topi_conv2d_nhwc_pack_int8.py new file mode 100644 index 0000000000000..441d9880f6ffe --- /dev/null +++ b/topi/tests/python/test_topi_conv2d_nhwc_pack_int8.py @@ -0,0 +1,71 @@ +"""Example code to do convolution.""" +import os +import numpy as np +import tvm +from tvm import autotvm +from tvm.autotvm.task.space import FallbackConfigEntity +import topi +import topi.testing +from tvm.contrib.pickle_memoize import memoize +from topi.util import get_const_tuple + + +def verify_conv2d_1x1_nhwc_pack_int8(batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation=1): + in_height = in_width = in_size + + A = tvm.placeholder((batch, in_height, in_width, in_channel), name='A', dtype='uint8') + W = tvm.placeholder((kernel, kernel, in_channel, num_filter), name='W', dtype='int8') + + a_shape = get_const_tuple(A.shape) + w_shape = (kernel, kernel, in_channel, num_filter) + dtype = A.dtype + + #@memoize("topi.tests.test_topi_conv2d_1x1_nhwc_pack_int8.verify_nhwc.v2") + def get_ref_data(): + a_np = np.random.uniform(size=a_shape).astype(dtype) + w_np = np.random.uniform(size=w_shape).astype(W.dtype) + dw_np = topi.testing.dilate_python(w_np, (dilation, dilation, 1, 1)) + b_np = topi.testing.conv2d_nhwc_python(a_np, dw_np, stride, padding) + return a_np, w_np, b_np + a_np, w_np, b_np = get_ref_data() + + def check_device(device): + ctx = tvm.context(device, 0) + if not ctx.exist: + print("Skip because %s is not enabled" % 
device) + return + print("Running on target: %s" % device) + + with tvm.target.create(device): + B = topi.nn.conv2d(A, W, stride, padding, dilation, layout='NHWC', out_dtype="int32") + s = topi.generic.schedule_conv2d_nhwc([B]) + a = tvm.nd.array(a_np, ctx) + w = tvm.nd.array(w_np, ctx) + b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx) + func = tvm.build(s, [A, W, B], device) + func(a, w, b) + tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5) + + for device in ['llvm -mcpu=skylake-avx512']: + check_device(device) + + +class DefaultFallback(autotvm.FallbackContext): + def _query_inside(self, target, workload): + key = (target, workload) + if key in self.memory: + return self.memory[key] + cfg = FallbackConfigEntity() + cfg.template_key = 'direct' + self.memory[key] = cfg + return cfg + + +def test_conv2d_nhwc(): + autotvm.DispatchContext.current.silent = True + with DefaultFallback(): + verify_conv2d_1x1_nhwc_pack_int8(1, 256, 32, 256, 1, 1, 0) + + +if __name__ == "__main__": + test_conv2d_nhwc() From ef573a5834277ac954de48702d8892f0f2c39f6f Mon Sep 17 00:00:00 2001 From: lingyiliu Date: Fri, 5 Apr 2019 15:09:00 -0700 Subject: [PATCH 02/17] memoize the test result --- .../python/test_topi_conv2d_NCHWc_int8.py | 71 +++++++++++++++++++ 1 file changed, 71 insertions(+) create mode 100644 topi/tests/python/test_topi_conv2d_NCHWc_int8.py diff --git a/topi/tests/python/test_topi_conv2d_NCHWc_int8.py b/topi/tests/python/test_topi_conv2d_NCHWc_int8.py new file mode 100644 index 0000000000000..2cf32dbc040e2 --- /dev/null +++ b/topi/tests/python/test_topi_conv2d_NCHWc_int8.py @@ -0,0 +1,71 @@ +"""Example code to do convolution.""" +import os +import numpy as np +import tvm +from tvm import autotvm +from tvm.autotvm.task.space import FallbackConfigEntity +import topi +import topi.testing +from tvm.contrib.pickle_memoize import memoize +from topi.util import get_const_tuple + + +def verify_conv2d_NCHWc_int8(batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation=1): + in_height = in_width = in_size + + A = tvm.placeholder((batch, in_height, in_width, in_channel), name='A', dtype='uint8') + W = tvm.placeholder((kernel, kernel, in_channel, num_filter, kernel, kernel), name='W', dtype='int8') + + a_shape = get_const_tuple(A.shape) + w_shape = (kernel, kernel, in_channel, num_filter) + dtype = A.dtype + + #@memoize("topi.tests.test_topi_conv2d_1x1_nhwc_pack_int8.verify_nhwc.v2") + def get_ref_data(): + a_np = np.random.uniform(size=a_shape).astype(dtype) + w_np = np.random.uniform(size=w_shape).astype(W.dtype) + dw_np = topi.testing.dilate_python(w_np, (dilation, dilation, 1, 1)) + b_np = topi.testing.conv2d_nchw_python(a_np, dw_np, stride, padding) + return a_np, w_np, b_np + a_np, w_np, b_np = get_ref_data() + + def check_device(device): + ctx = tvm.context(device, 0) + if not ctx.exist: + print("Skip because %s is not enabled" % device) + return + print("Running on target: %s" % device) + + with tvm.target.create(device): + B = topi.nn.conv2d(A, W, stride, padding, dilation, layout='NHWC', out_dtype="int32") + s = topi.generic.schedule_conv2d_nhwc([B]) + a = tvm.nd.array(a_np, ctx) + w = tvm.nd.array(w_np, ctx) + b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx) + func = tvm.build(s, [A, W, B], device) + func(a, w, b) + tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5) + + for device in ['llvm -mcpu=skylake-avx512']: + check_device(device) + + +class DefaultFallback(autotvm.FallbackContext): + def _query_inside(self, 
target, workload): + key = (target, workload) + if key in self.memory: + return self.memory[key] + cfg = FallbackConfigEntity() + cfg.template_key = 'direct' + self.memory[key] = cfg + return cfg + + +def test_conv2d_NCHWc(): + autotvm.DispatchContext.current.silent = True + with DefaultFallback(): + verify_conv2d_NCHWc_int8(1, 256, 32, 256, 7, 7, 0) + + +if __name__ == "__main__": + test_conv2d_NCHWc() From 88909e37472f119c0109c54044a35dc7224ec330 Mon Sep 17 00:00:00 2001 From: lingyiliu Date: Fri, 5 Apr 2019 15:14:49 -0700 Subject: [PATCH 03/17] fix the wrong file --- .../python/test_topi_conv2d_NCHWc_int8.py | 71 ------------------- .../python/test_topi_conv2d_nhwc_pack_int8.py | 2 +- 2 files changed, 1 insertion(+), 72 deletions(-) delete mode 100644 topi/tests/python/test_topi_conv2d_NCHWc_int8.py diff --git a/topi/tests/python/test_topi_conv2d_NCHWc_int8.py b/topi/tests/python/test_topi_conv2d_NCHWc_int8.py deleted file mode 100644 index 2cf32dbc040e2..0000000000000 --- a/topi/tests/python/test_topi_conv2d_NCHWc_int8.py +++ /dev/null @@ -1,71 +0,0 @@ -"""Example code to do convolution.""" -import os -import numpy as np -import tvm -from tvm import autotvm -from tvm.autotvm.task.space import FallbackConfigEntity -import topi -import topi.testing -from tvm.contrib.pickle_memoize import memoize -from topi.util import get_const_tuple - - -def verify_conv2d_NCHWc_int8(batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation=1): - in_height = in_width = in_size - - A = tvm.placeholder((batch, in_height, in_width, in_channel), name='A', dtype='uint8') - W = tvm.placeholder((kernel, kernel, in_channel, num_filter, kernel, kernel), name='W', dtype='int8') - - a_shape = get_const_tuple(A.shape) - w_shape = (kernel, kernel, in_channel, num_filter) - dtype = A.dtype - - #@memoize("topi.tests.test_topi_conv2d_1x1_nhwc_pack_int8.verify_nhwc.v2") - def get_ref_data(): - a_np = np.random.uniform(size=a_shape).astype(dtype) - w_np = np.random.uniform(size=w_shape).astype(W.dtype) - dw_np = topi.testing.dilate_python(w_np, (dilation, dilation, 1, 1)) - b_np = topi.testing.conv2d_nchw_python(a_np, dw_np, stride, padding) - return a_np, w_np, b_np - a_np, w_np, b_np = get_ref_data() - - def check_device(device): - ctx = tvm.context(device, 0) - if not ctx.exist: - print("Skip because %s is not enabled" % device) - return - print("Running on target: %s" % device) - - with tvm.target.create(device): - B = topi.nn.conv2d(A, W, stride, padding, dilation, layout='NHWC', out_dtype="int32") - s = topi.generic.schedule_conv2d_nhwc([B]) - a = tvm.nd.array(a_np, ctx) - w = tvm.nd.array(w_np, ctx) - b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx) - func = tvm.build(s, [A, W, B], device) - func(a, w, b) - tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5) - - for device in ['llvm -mcpu=skylake-avx512']: - check_device(device) - - -class DefaultFallback(autotvm.FallbackContext): - def _query_inside(self, target, workload): - key = (target, workload) - if key in self.memory: - return self.memory[key] - cfg = FallbackConfigEntity() - cfg.template_key = 'direct' - self.memory[key] = cfg - return cfg - - -def test_conv2d_NCHWc(): - autotvm.DispatchContext.current.silent = True - with DefaultFallback(): - verify_conv2d_NCHWc_int8(1, 256, 32, 256, 7, 7, 0) - - -if __name__ == "__main__": - test_conv2d_NCHWc() diff --git a/topi/tests/python/test_topi_conv2d_nhwc_pack_int8.py b/topi/tests/python/test_topi_conv2d_nhwc_pack_int8.py index 441d9880f6ffe..2171e31d9f899 100644 --- 
a/topi/tests/python/test_topi_conv2d_nhwc_pack_int8.py +++ b/topi/tests/python/test_topi_conv2d_nhwc_pack_int8.py @@ -20,7 +20,7 @@ def verify_conv2d_1x1_nhwc_pack_int8(batch, in_channel, in_size, num_filter, ker w_shape = (kernel, kernel, in_channel, num_filter) dtype = A.dtype - #@memoize("topi.tests.test_topi_conv2d_1x1_nhwc_pack_int8.verify_nhwc.v2") + @memoize("topi.tests.test_topi_conv2d_1x1_nhwc_pack_int8.verify_nhwc.v2") def get_ref_data(): a_np = np.random.uniform(size=a_shape).astype(dtype) w_np = np.random.uniform(size=w_shape).astype(W.dtype) From 7e05f0bddc6005e0530f02a8c254f155311dd6e1 Mon Sep 17 00:00:00 2001 From: lingyiliu Date: Sat, 6 Apr 2019 10:55:30 -0700 Subject: [PATCH 04/17] Add the int8 group conv support on x86 --- topi/python/topi/x86/conv2d.py | 23 ++++- .../test_topi_group_conv2d_NCHWc_int8.py | 93 +++++++++++++++++++ 2 files changed, 113 insertions(+), 3 deletions(-) create mode 100644 topi/tests/python/test_topi_group_conv2d_NCHWc_int8.py diff --git a/topi/python/topi/x86/conv2d.py b/topi/python/topi/x86/conv2d.py index 9934843286222..1cd28ddda0234 100644 --- a/topi/python/topi/x86/conv2d.py +++ b/topi/python/topi/x86/conv2d.py @@ -416,10 +416,11 @@ def _declaration_conv_NCHWc(cfg, data, kernel, strides, n, ic_chunk, ih, iw, ic_bn = get_const_tuple(data.shape) in_channel = ic_chunk * ic_bn if data.dtype == 'uint8': - oc_chunk, _, kernel_height, kernel_width, _, oc_bn, _ = get_const_tuple(kernel.shape) + oc_chunk, ic_chunk_group, kernel_height, kernel_width, _, oc_bn, _ = get_const_tuple(kernel.shape) else: - oc_chunk, _, kernel_height, kernel_width, _, oc_bn = get_const_tuple(kernel.shape) + oc_chunk, ic_chunk_group, kernel_height, kernel_width, _, oc_bn = get_const_tuple(kernel.shape) num_filter = oc_chunk * oc_bn + groups = ic_chunk // ic_chunk_group if cfg.is_fallback: _get_default_config(cfg, tvm.placeholder((n, in_channel, ih, iw), dtype=data.dtype), @@ -443,7 +444,7 @@ def _declaration_conv_NCHWc(cfg, data, kernel, strides, kh = tvm.reduce_axis((0, kernel_height), name='kh') kw = tvm.reduce_axis((0, kernel_width), name='kw') - if data.dtype == 'uint8': + if data.dtype == 'uint8' and groups == 1: assert out_dtype == "int32", \ "INT8 convolution requires input dtype = uint8 and output dtype=int32" # Intel performs dot product of 2 "4" Int8 values @@ -462,6 +463,22 @@ def _declaration_conv_NCHWc(cfg, data, kernel, strides, oc_block, ic_s_inner].astype(out_dtype), axis=[kh, kw, ic_outer, ic_f_inner, ic_s_inner]), name='conv2d_NCHWc_int8', tag="conv2d_NCHWc_int8") + elif data.dtype == 'uint8': + # for int8 group conv support + n_elems = 4 + ic_chunk = in_channel//ic_bn + ic_outer = tvm.reduce_axis((0, ic_chunk//groups), name='ic_outer') + ic_f_inner = tvm.reduce_axis((0, ic_bn//n_elems), name='ic_f_inner') + ic_s_inner = tvm.reduce_axis((0, n_elems), name='ic_s_inner') + oshape = (n, oc_chunk, out_height, out_width, oc_bn) + return tvm.compute(oshape, lambda n, occ, oh, ow, oc_block: + tvm.sum(data_pad[n, (occ*oc_bn//(oc_chunk*oc_bn//groups))*(ic_chunk//groups)+ic_outer, oh*HSTR+kh, ow*WSTR+kw, + ic_f_inner * n_elems + ic_s_inner].astype(out_dtype) * + kernel[occ, ic_outer, kh, kw, ic_f_inner, + oc_block, ic_s_inner].astype(out_dtype), + axis=[kh, kw, ic_outer, ic_f_inner, ic_s_inner]), + name='conv2d_NCHWc_int8', tag="conv2d_NCHWc_int8") + # else: fp implementation return tvm.compute(oshape, lambda n, oc_chunk, oh, ow, oc_block: tvm.sum(data_pad[n, ic//ic_bn, oh*HSTR+kh, ow*WSTR+kw, diff --git a/topi/tests/python/test_topi_group_conv2d_NCHWc_int8.py 
b/topi/tests/python/test_topi_group_conv2d_NCHWc_int8.py new file mode 100644 index 0000000000000..34a437838cc67 --- /dev/null +++ b/topi/tests/python/test_topi_group_conv2d_NCHWc_int8.py @@ -0,0 +1,93 @@ +"""Test for NCHW[x]c convolution""" + +import numpy as np +import tvm +from tvm import autotvm +import topi +import topi.testing +from tvm.contrib.pickle_memoize import memoize +from topi.util import get_const_tuple + +from common import get_all_backend + +def _transform_data(data, bn): + # NCHW -> NCHW[x]c + batch_size, channel, height, width = data.shape + data = np.reshape(data, (batch_size, channel//bn, bn, height, width)) + data = np.transpose(data, (0, 1, 3, 4, 2)) + return data + +def _transform_kernel(kernel, ic_bn, oc_bn): + # OIHW -> OIHW[x]i[x]o + out_channel, in_channel, kh, kw = kernel.shape + kernel = np.reshape(kernel, (out_channel//oc_bn, oc_bn, in_channel//ic_bn, ic_bn//4, kh, kw, 4)) + kernel = np.transpose(kernel, (0, 2, 4, 5, 3, 1, 6)) + return kernel + +def verify_group_conv2d_NCHWc_int8(batch, in_channel, groups, in_size, num_filter, kernel, stride, + padding, dilation=1, add_bias=False, add_relu=False, dtype="int32"): + assert dilation == 1, "conv2d_NCHWc does not support dilation for now." + print("Workload: (%d, %d, %d, %d, %d, %d, %d, %d)" % + (batch, in_channel, groups, in_size, num_filter, kernel, stride, padding)) + + in_height = in_width = in_size + + # for testing functionality, + # we choose arbitrary block size that can divide the channel, + # regardless of the performance. + oc_block = 1 + for bn in range(16, 0, -1): + if num_filter % bn == 0: + oc_block = bn + break + + ic_block = 8 + autotvm.DispatchContext.current.silent = True + A = tvm.placeholder((batch, in_channel//ic_block, in_height, in_width, ic_block), name='A', dtype='uint8') + W = tvm.placeholder((num_filter//oc_block, in_channel//ic_block//groups, kernel, kernel, ic_block//4, oc_block, 4), name='W', dtype='int8') + + @memoize("topi.tests.test_topi_conv2d_NCHWc_int8.verify_conv2d_NCHWc_int8") + def get_ref_data(): + a_np = np.random.uniform(size=(batch, in_channel, in_height, in_width)).astype("uint8") + w_np = np.random.uniform(size=(num_filter, in_channel//groups, kernel, kernel)).astype("int8") + c_np = topi.testing.conv2d_nchw_python(a_np, w_np, stride, padding, groups) + return _transform_data(a_np, ic_block), _transform_kernel(w_np, ic_block, oc_block), \ + _transform_data(c_np, oc_block) + + a_np, w_np, c_np = get_ref_data() + + def check_device(device): + ctx = tvm.context(device, 0) + if not ctx.exist: + print("Skip because %s is not enabled" % device) + return + print("Running on target: %s" % device) + with tvm.target.create(device): + C = topi.nn.conv2d_NCHWc(A, W, (stride, stride), (padding, padding), + (dilation, dilation), + layout='NCHW%dc'%ic_block, + out_layout="NCHW%dc"%oc_block, + out_dtype=dtype) + s = topi.generic.schedule_conv2d_NCHWc([C]) + + a = tvm.nd.array(a_np, ctx) + w = tvm.nd.array(w_np, ctx) + c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), ctx) + func = tvm.build(s, [A, W, C], device, + name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % + (batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation)) + # print(tvm.lower(s, [A, W, C], simple_mode=True)) + func(a, w, c) + tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-3) + + for device in ["llvm -mcpu=skylake-avx512"]: + with autotvm.tophub.context(device): # load tophub pre-tuned parameters + check_device(device) + + +def test_conv2d_NCHWc(): + # ResNet50 workloads + 
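Note on the grouped int8 NCHWc compute added in this patch: the input-channel block read for a given output-channel block is selected by the expression (occ*oc_bn//(oc_chunk*oc_bn//groups))*(ic_chunk//groups)+ic_outer. A short pure-Python sketch of that index arithmetic, using hypothetical block sizes (nothing below is taken from the patch beyond the expression itself):

    # Hypothetical sizes: 256 in/out channels, 32 groups, 8-channel blocks.
    # Assumes oc_bn divides the per-group output-channel count, as the
    # expression in the compute above implicitly does.
    oc_chunk, oc_bn = 32, 8
    ic_chunk, ic_bn = 32, 8
    groups = 32

    for occ in range(oc_chunk):
        # group that owns this block of output channels
        group = occ * oc_bn // (oc_chunk * oc_bn // groups)
        # first input-channel chunk of that group; the reduce axis ic_outer
        # then runs over range(ic_chunk // groups) starting from this chunk
        base_ic_chunk = group * (ic_chunk // groups)
        assert 0 <= base_ic_chunk + (ic_chunk // groups) - 1 < ic_chunk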
verify_group_conv2d_NCHWc_int8(1, 256, 32, 224, 64, 7, 2, 3) + +if __name__ == "__main__": + test_conv2d_NCHWc() From 2fc95feef31656c287f3d77628565bf6e21f513e Mon Sep 17 00:00:00 2001 From: lingyiliu Date: Fri, 5 Apr 2019 13:54:17 -0700 Subject: [PATCH 05/17] Support the 1x1 int8 conv with NHWC layout and weight packing --- topi/python/topi/nn/conv2d.py | 24 +++- topi/python/topi/x86/conv2d.py | 49 ++++++-- topi/python/topi/x86/conv2d_avx_1x1.py | 106 +++++++++++++++++- .../python/test_topi_conv2d_nhwc_pack_int8.py | 71 ++++++++++++ 4 files changed, 236 insertions(+), 14 deletions(-) create mode 100644 topi/tests/python/test_topi_conv2d_nhwc_pack_int8.py diff --git a/topi/python/topi/nn/conv2d.py b/topi/python/topi/nn/conv2d.py index 49c0bd79eacc5..6b739ee030577 100644 --- a/topi/python/topi/nn/conv2d.py +++ b/topi/python/topi/nn/conv2d.py @@ -28,7 +28,7 @@ # workload description of conv2d Workload = namedtuple('Workload', - ['in_dtype', 'out_dtype', 'height', 'width', 'in_filter', 'out_filter', + ['in_dtype', 'out_dtype', 'height', 'width', 'in_filter', 'groups', 'out_filter', 'hkernel', 'wkernel', 'hpad', 'wpad', 'hstride', 'wstride']) @tvm.target.generic_func @@ -95,11 +95,25 @@ def conv2d_alter_layout(attrs, inputs, tinfos, F): return None -def _get_workload(data, kernel, stride, padding, out_dtype): +def _get_workload(data, kernel, stride, padding, out_dtype, data_layout='NCHW'): """ Get the workload structure. """ - _, CI, IH, IW = [x.value for x in data.shape] - CO, _, KH, KW = [x.value for x in kernel.shape] + if data_layout == 'NCHW': + _, CI, IH, IW = [x.value for x in data.shape] + elif data_layout == 'NHWC': + _, IH, IW, CI = [x.value for x in data.shape] + elif data_layout == 'HWCN': + IH, IW, CI, _ = [x.value for x in data.shape] + else: + raise ValueError("not support this layout {} yet".format(data_layout)) + + + if data_layout == 'NHWC': + KH, KW, CO, CIG = [x.value for x in kernel.shape] + else: + CO, CIG, KH, KW = [x.value for x in kernel.shape] + HPAD, WPAD, _, _ = get_pad_tuple(padding, kernel) + GRPS = CI // CIG if isinstance(stride, (tuple, list)): HSTR, WSTR = stride else: @@ -107,7 +121,7 @@ def _get_workload(data, kernel, stride, padding, out_dtype): assert (data.dtype == kernel.dtype) or (data.dtype == 'uint8' and kernel.dtype == 'int8'), \ "Do not support inputs with different data types now. ' \ '{} vs. 
{}".format(data.dtype, kernel.dtype) - return Workload(data.dtype, out_dtype, IH, IW, CI, CO, KH, KW, HPAD, WPAD, HSTR, WSTR) + return Workload(data.dtype, out_dtype, IH, IW, CI, GRPS, CO, KH, KW, HPAD, WPAD, HSTR, WSTR) def conv2d_nchw(Input, Filter, stride, padding, dilation, out_dtype=None): diff --git a/topi/python/topi/x86/conv2d.py b/topi/python/topi/x86/conv2d.py index f46c948bdeb10..7333978b5351d 100644 --- a/topi/python/topi/x86/conv2d.py +++ b/topi/python/topi/x86/conv2d.py @@ -37,7 +37,7 @@ logger = logging.getLogger('topi') -def _get_default_config(cfg, data, kernel, strides, padding, out_dtype, is_depthwise=False): +def _get_default_config(cfg, data, kernel, strides, padding, out_dtype, is_depthwise=False, layout='NCHW'): """ Get default schedule config for the workload """ @@ -46,7 +46,7 @@ def _get_default_config(cfg, data, kernel, strides, padding, out_dtype, is_depth from .depthwise_conv2d import _fallback_schedule _fallback_schedule(cfg, wkl) else: - wkl = _get_conv2d_workload(data, kernel, strides, padding, out_dtype) + wkl = _get_conv2d_workload(data, kernel, strides, padding, out_dtype, layout) is_kernel_1x1 = wkl.hkernel == 1 and wkl.wkernel == 1 if is_kernel_1x1: conv2d_avx_1x1._fallback_schedule(cfg, wkl) @@ -62,6 +62,8 @@ def _create_tuning_space(cfg, data, kernel, strides, padding, dilation, layout): if layout == 'NCHW': n, ic, h, w = dshape oc, _, kh, kw = kshape + elif layout == 'NHWC': + n, h, w, ic = dshape elif pat.match(layout) is not None: n, ic_chunk, h, w, ic_bn = dshape if data.dtype == 'uint8': @@ -93,12 +95,14 @@ def _create_tuning_space(cfg, data, kernel, strides, padding, dilation, layout): cfg.define_knob("unroll_kw", [True, False]) -@autotvm.register_topi_compute(conv2d, 'cpu', 'direct') +@autotvm.register_topi_compute(conv2d, 'cpu', ['direct']) def _declaration_conv(cfg, data, kernel, strides, padding, dilation, layout, out_dtype): out_dtype = data.dtype if out_dtype is None else out_dtype padding = padding if isinstance(padding, (tuple, list)) else (padding, padding) strides = strides if isinstance(strides, (tuple, list)) else (strides, strides) dilation = dilation if isinstance(dilation, (tuple, list)) else (dilation, dilation) + + _, _, kh, kw = get_const_tuple(kernel.shape) if layout == 'NCHW': _create_tuning_space(cfg, data, kernel, strides, padding, dilation, layout) if cfg.is_fallback: @@ -107,7 +111,13 @@ def _declaration_conv(cfg, data, kernel, strides, padding, dilation, layout, out padding, dilation, layout, out_dtype) if layout == 'HWCN': return nn.conv2d_hwcn(data, kernel, strides, padding, dilation, out_dtype) - if layout == 'NHWC': + elif layout == 'NHWC' and kh == 1 and kw == 1 and kernel.dtype == "int8": + if cfg.is_fallback: + _get_default_config(cfg, data, kernel, strides, padding, out_dtype, False, layout) + # specialize for INT8 1X1 conv on X86 + return conv2d_avx_1x1._declaration_conv_nhwc_pack(cfg, data, kernel, strides, + padding, dilation, out_dtype) + elif layout == 'NHWC': return nn.conv2d_nhwc(data, kernel, strides, padding, dilation, out_dtype) raise ValueError("not support this layout {} yet".format(layout)) @@ -226,8 +236,9 @@ def traverse(op): return s -@generic.schedule_conv2d_nhwc.register("cpu") -def schedule_conv2d_nhwc(outs): +# @generic.schedule_conv2d_nhwc.register("cpu") +@autotvm.register_topi_schedule(generic.schedule_conv2d_nhwc, 'cpu', ['direct']) +def schedule_conv2d_nhwc(cfg, outs): """Create schedule for tensors""" s = tvm.create_schedule([x.op for x in outs]) output_op = outs[0].op @@ -249,7 +260,31 
@@ def traverse(op): if tensor.op.input_tensors and tensor.op not in scheduled_ops: traverse(tensor.op) - if 'conv2d_nhwc' in op.tag: + if 'conv2d_nhwc_pack_int8' in op.tag: + conv_out = op.output(0) + kernel = conv_out.op.input_tensors[1] + data_vec = conv_out.op.input_tensors[0] + data = data_vec.op.input_tensors[0] \ + if isinstance(data_vec.op, tvm.tensor.ComputeOp) and "pad" not in data_vec.op.tag \ + else data_vec + if isinstance(data.op, tvm.tensor.ComputeOp) and "pad" in data.op.tag: + data_pad = data + data = data_pad.op.input_tensors[0] + + args = [s, cfg, data_vec, conv_out, outs[0]] + if data.dtype == 'uint8': + # int8 conv kernel is 7-dim + kh, kw, _, _, _ = get_const_tuple(kernel.shape) + if kh == 1 and kw == 1: + conv2d_avx_1x1._schedule_conv_nhwc_pack_int8(*args) + else: + raise ValueError("Only support 1x1 kernel with " + "schedule template.") + else: + raise ValueError("Not support this data type {} with " + "schedule template.".format(data.dtype)) + + elif 'conv2d_nhwc' in op.tag: conv = op.output(0) kernel = op.input_tensors[1] if isinstance(kernel.op, tvm.tensor.ComputeOp) and "dilate" in kernel.op.tag: diff --git a/topi/python/topi/x86/conv2d_avx_1x1.py b/topi/python/topi/x86/conv2d_avx_1x1.py index bcd2cefc2bdf2..29bb0802e3c59 100644 --- a/topi/python/topi/x86/conv2d_avx_1x1.py +++ b/topi/python/topi/x86/conv2d_avx_1x1.py @@ -20,8 +20,9 @@ import tvm from tvm.autotvm.task.space import SplitEntity, OtherOptionEntity -from ..nn.util import infer_pad -from ..util import get_const_tuple +from ..nn.pad import pad +from ..nn.util import infer_pad, get_pad_tuple +from ..util import get_const_tuple, simplify from .tensor_intrin import dot_16x1x16_int8_int8_int32 from .check_targets import check_skylake from .util import get_fp32_len @@ -251,3 +252,104 @@ def _schedule_conv_NCHWc_int8(s, cfg, data, conv_out, last): s[O].parallel(parallel_axis) return s + + +def _declaration_conv_nhwc_pack(cfg, Input, Filter, stride, padding, dilation, out_dtype): + # more assertion for the shapes + assert isinstance(stride, int) or len(stride) == 2 + assert isinstance(dilation, int) or len(dilation) == 2 + if isinstance(stride, int): + stride_h = stride_w = stride + else: + stride_h, stride_w = stride + + if isinstance(dilation, int): + dilation_h = dilation_w = dilation + else: + dilation_h, dilation_w = dilation + + batch, in_height, in_width, in_channel = Input.shape + kernel_h, kernel_w, num_filter, channel = Filter.shape + + # compute the output shape + dilated_kernel_h = (kernel_h - 1) * dilation_h + 1 + dilated_kernel_w = (kernel_w - 1) * dilation_w + 1 + pad_top, pad_left, pad_down, pad_right = get_pad_tuple( + padding, (dilated_kernel_h, dilated_kernel_w)) + out_channel = num_filter + out_height = simplify((in_height - dilated_kernel_h + pad_top + pad_down) // stride_h + 1) + out_width = simplify((in_width - dilated_kernel_w + pad_left + pad_right) // stride_w + 1) + pad_before = [0, pad_top, pad_left, 0] + pad_after = [0, pad_down, pad_right, 0] + PaddedInput = pad(Input, pad_before, pad_after, name="PaddedInput") + # todo: padding filter to accomodate the intrinsic + + # packing the Filter to let memory access be consecutive for AVX512 intrinsic + # Done in pre-compute stage + packw_shape = (kernel_h, kernel_w, num_filter/16, 16*(channel/4), 4) + PackW = tvm.compute(packw_shape, lambda a, b, c, d, e: Filter[a][b][c*16+d%16][d/16*4+e], name="packed_filter") + + rc = tvm.reduce_axis((0, in_channel), name='rc') + ry = tvm.reduce_axis((0, kernel_h), name='ry') + rx = 
tvm.reduce_axis((0, kernel_w), name='rx') + Output = tvm.compute( + (batch, out_height, out_width, out_channel), + lambda nn, yy, xx, ff: tvm.sum( + PaddedInput[nn, yy * stride_h + ry * dilation_h, + xx * stride_w + rx * dilation_w, rc].astype(out_dtype) * + PackW[ry, rx, ff/16, (rc/4)*16+ff%16, rc%4].astype(out_dtype), axis=[ry, rx, rc]), + name="Conv2d_1x1_Output_int8", tag="conv2d_nhwc_pack_int8") + return Output + + +def _schedule_conv_nhwc_pack_int8(s, cfg, data, conv_out, last): + """ + Defines the schedule for the int8 nhwc layout. For 1x1 conv, it + is a matrix-multiply operation by using nhwc layout. We will do + packing of weight to make the address access be friendly to int8 + intrinsic + """ + target = tvm.target.current_target(allow_none=False) + int32_lanes = -1 + if check_skylake(target): + int32_lanes = 16 + else: + return s + assert int32_lanes != -1 + + # assertion to fail the unhandled case + _, _, _, ic_num = get_const_tuple(data.shape) + _, _, _, oc_num = get_const_tuple(conv_out.shape) + assert ic_num % 4 == 0 + assert oc_num % 16 == 0 + + ic_factor, oc_factor = cfg["tile_ic"].size[-1], cfg["tile_oc"].size[-1] + # schedule data + A = data + if isinstance(s[A].op, tvm.tensor.ComputeOp): + batch, ih, iw, ic = s[A].op.axis + d_ic_chunk, d_ic_block = s[A].split(ic, factor=4) + s[A].vectorize(d_ic_block) + + C, O = conv_out, last + + batch, oh, ow, oc = s[C].op.axis + kh, kw, ic = s[C].op.reduce_axis + # match the x86 intrinsic + ic_outer, ic_inner = s[C].split(ic, factor=4) + oc_outer, oc_inner = s[C].split(oc, factor=int32_lanes) + + ic_f_outer, ic_s_outer = s[C].split(ic_outer, factor=ic_factor) + s[C].reorder(oc_outer, oh, ow, ic_f_outer, ic_s_outer, kh, kw, oc_inner, ic_inner) + + pc = dot_16x1x16_int8_int8_int32() + s[C].tensorize(oc_inner, pc) + + if C != O: + batch, last_oh, last_ow, last_oc = s[O].op.axis + oc_chunk, oc_block = s[O].split(ochannel, 16) + # not saw perf improvement to split oh/ow here + s[O].vectorize(oc_block) + + return s + diff --git a/topi/tests/python/test_topi_conv2d_nhwc_pack_int8.py b/topi/tests/python/test_topi_conv2d_nhwc_pack_int8.py new file mode 100644 index 0000000000000..441d9880f6ffe --- /dev/null +++ b/topi/tests/python/test_topi_conv2d_nhwc_pack_int8.py @@ -0,0 +1,71 @@ +"""Example code to do convolution.""" +import os +import numpy as np +import tvm +from tvm import autotvm +from tvm.autotvm.task.space import FallbackConfigEntity +import topi +import topi.testing +from tvm.contrib.pickle_memoize import memoize +from topi.util import get_const_tuple + + +def verify_conv2d_1x1_nhwc_pack_int8(batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation=1): + in_height = in_width = in_size + + A = tvm.placeholder((batch, in_height, in_width, in_channel), name='A', dtype='uint8') + W = tvm.placeholder((kernel, kernel, in_channel, num_filter), name='W', dtype='int8') + + a_shape = get_const_tuple(A.shape) + w_shape = (kernel, kernel, in_channel, num_filter) + dtype = A.dtype + + #@memoize("topi.tests.test_topi_conv2d_1x1_nhwc_pack_int8.verify_nhwc.v2") + def get_ref_data(): + a_np = np.random.uniform(size=a_shape).astype(dtype) + w_np = np.random.uniform(size=w_shape).astype(W.dtype) + dw_np = topi.testing.dilate_python(w_np, (dilation, dilation, 1, 1)) + b_np = topi.testing.conv2d_nhwc_python(a_np, dw_np, stride, padding) + return a_np, w_np, b_np + a_np, w_np, b_np = get_ref_data() + + def check_device(device): + ctx = tvm.context(device, 0) + if not ctx.exist: + print("Skip because %s is not enabled" % device) + return 
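Note on the schedule above: it hands the oc_inner (16 lanes) by ic_inner (4 elements) loops to dot_16x1x16_int8_int8_int32. A minimal NumPy sketch of the reduction that intrinsic covers, with illustrative names only (not code from this patch):

    import numpy as np

    def dot_16x1x16_ref(data_u8, kernel_i8, acc_i32):
        # data_u8:   (4,)    uint8 -- four consecutive input channels
        # kernel_i8: (16, 4) int8  -- 16 output channels x the same 4 input channels
        # acc_i32:   (16,)   int32 -- one accumulator lane per output channel
        acc_i32 += kernel_i8.astype(np.int32) @ data_u8.astype(np.int32)
        return acc_i32

The packed-filter indexing and the read pattern in _declaration_conv_nhwc_pack compose back to a direct Filter[kh, kw, :, :] access; a pure-Python check of that index arithmetic with toy channel counts (illustrative only, '//' stands in for the '/' integer division written inside tvm.compute):

    # PackW[a, b, c, d, e] = Filter[a][b][c*16 + d % 16][d // 16 * 4 + e]
    # the conv reads PackW[ry, rx, ff // 16, (rc // 4) * 16 + ff % 16, rc % 4]
    for ff in range(64):          # toy output-channel range
        for rc in range(32):      # toy input-channel range
            c, d, e = ff // 16, (rc // 4) * 16 + ff % 16, rc % 4
            assert (c * 16 + d % 16, d // 16 * 4 + e) == (ff, rc)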
+ print("Running on target: %s" % device) + + with tvm.target.create(device): + B = topi.nn.conv2d(A, W, stride, padding, dilation, layout='NHWC', out_dtype="int32") + s = topi.generic.schedule_conv2d_nhwc([B]) + a = tvm.nd.array(a_np, ctx) + w = tvm.nd.array(w_np, ctx) + b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx) + func = tvm.build(s, [A, W, B], device) + func(a, w, b) + tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5) + + for device in ['llvm -mcpu=skylake-avx512']: + check_device(device) + + +class DefaultFallback(autotvm.FallbackContext): + def _query_inside(self, target, workload): + key = (target, workload) + if key in self.memory: + return self.memory[key] + cfg = FallbackConfigEntity() + cfg.template_key = 'direct' + self.memory[key] = cfg + return cfg + + +def test_conv2d_nhwc(): + autotvm.DispatchContext.current.silent = True + with DefaultFallback(): + verify_conv2d_1x1_nhwc_pack_int8(1, 256, 32, 256, 1, 1, 0) + + +if __name__ == "__main__": + test_conv2d_nhwc() From 5c19271b5f0aee68baa67498b5b3ee9fb1ec218a Mon Sep 17 00:00:00 2001 From: lingyiliu Date: Fri, 5 Apr 2019 15:09:00 -0700 Subject: [PATCH 06/17] memoize the test result --- .../python/test_topi_conv2d_NCHWc_int8.py | 71 +++++++++++++++++++ 1 file changed, 71 insertions(+) create mode 100644 topi/tests/python/test_topi_conv2d_NCHWc_int8.py diff --git a/topi/tests/python/test_topi_conv2d_NCHWc_int8.py b/topi/tests/python/test_topi_conv2d_NCHWc_int8.py new file mode 100644 index 0000000000000..2cf32dbc040e2 --- /dev/null +++ b/topi/tests/python/test_topi_conv2d_NCHWc_int8.py @@ -0,0 +1,71 @@ +"""Example code to do convolution.""" +import os +import numpy as np +import tvm +from tvm import autotvm +from tvm.autotvm.task.space import FallbackConfigEntity +import topi +import topi.testing +from tvm.contrib.pickle_memoize import memoize +from topi.util import get_const_tuple + + +def verify_conv2d_NCHWc_int8(batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation=1): + in_height = in_width = in_size + + A = tvm.placeholder((batch, in_height, in_width, in_channel), name='A', dtype='uint8') + W = tvm.placeholder((kernel, kernel, in_channel, num_filter, kernel, kernel), name='W', dtype='int8') + + a_shape = get_const_tuple(A.shape) + w_shape = (kernel, kernel, in_channel, num_filter) + dtype = A.dtype + + #@memoize("topi.tests.test_topi_conv2d_1x1_nhwc_pack_int8.verify_nhwc.v2") + def get_ref_data(): + a_np = np.random.uniform(size=a_shape).astype(dtype) + w_np = np.random.uniform(size=w_shape).astype(W.dtype) + dw_np = topi.testing.dilate_python(w_np, (dilation, dilation, 1, 1)) + b_np = topi.testing.conv2d_nchw_python(a_np, dw_np, stride, padding) + return a_np, w_np, b_np + a_np, w_np, b_np = get_ref_data() + + def check_device(device): + ctx = tvm.context(device, 0) + if not ctx.exist: + print("Skip because %s is not enabled" % device) + return + print("Running on target: %s" % device) + + with tvm.target.create(device): + B = topi.nn.conv2d(A, W, stride, padding, dilation, layout='NHWC', out_dtype="int32") + s = topi.generic.schedule_conv2d_nhwc([B]) + a = tvm.nd.array(a_np, ctx) + w = tvm.nd.array(w_np, ctx) + b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx) + func = tvm.build(s, [A, W, B], device) + func(a, w, b) + tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5) + + for device in ['llvm -mcpu=skylake-avx512']: + check_device(device) + + +class DefaultFallback(autotvm.FallbackContext): + def _query_inside(self, target, workload): + 
key = (target, workload) + if key in self.memory: + return self.memory[key] + cfg = FallbackConfigEntity() + cfg.template_key = 'direct' + self.memory[key] = cfg + return cfg + + +def test_conv2d_NCHWc(): + autotvm.DispatchContext.current.silent = True + with DefaultFallback(): + verify_conv2d_NCHWc_int8(1, 256, 32, 256, 7, 7, 0) + + +if __name__ == "__main__": + test_conv2d_NCHWc() From 855a19c317b6094f492f05400efdd756c763846d Mon Sep 17 00:00:00 2001 From: lingyiliu Date: Fri, 5 Apr 2019 15:14:49 -0700 Subject: [PATCH 07/17] fix the wrong file --- .../python/test_topi_conv2d_NCHWc_int8.py | 71 ------------------- .../python/test_topi_conv2d_nhwc_pack_int8.py | 2 +- 2 files changed, 1 insertion(+), 72 deletions(-) delete mode 100644 topi/tests/python/test_topi_conv2d_NCHWc_int8.py diff --git a/topi/tests/python/test_topi_conv2d_NCHWc_int8.py b/topi/tests/python/test_topi_conv2d_NCHWc_int8.py deleted file mode 100644 index 2cf32dbc040e2..0000000000000 --- a/topi/tests/python/test_topi_conv2d_NCHWc_int8.py +++ /dev/null @@ -1,71 +0,0 @@ -"""Example code to do convolution.""" -import os -import numpy as np -import tvm -from tvm import autotvm -from tvm.autotvm.task.space import FallbackConfigEntity -import topi -import topi.testing -from tvm.contrib.pickle_memoize import memoize -from topi.util import get_const_tuple - - -def verify_conv2d_NCHWc_int8(batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation=1): - in_height = in_width = in_size - - A = tvm.placeholder((batch, in_height, in_width, in_channel), name='A', dtype='uint8') - W = tvm.placeholder((kernel, kernel, in_channel, num_filter, kernel, kernel), name='W', dtype='int8') - - a_shape = get_const_tuple(A.shape) - w_shape = (kernel, kernel, in_channel, num_filter) - dtype = A.dtype - - #@memoize("topi.tests.test_topi_conv2d_1x1_nhwc_pack_int8.verify_nhwc.v2") - def get_ref_data(): - a_np = np.random.uniform(size=a_shape).astype(dtype) - w_np = np.random.uniform(size=w_shape).astype(W.dtype) - dw_np = topi.testing.dilate_python(w_np, (dilation, dilation, 1, 1)) - b_np = topi.testing.conv2d_nchw_python(a_np, dw_np, stride, padding) - return a_np, w_np, b_np - a_np, w_np, b_np = get_ref_data() - - def check_device(device): - ctx = tvm.context(device, 0) - if not ctx.exist: - print("Skip because %s is not enabled" % device) - return - print("Running on target: %s" % device) - - with tvm.target.create(device): - B = topi.nn.conv2d(A, W, stride, padding, dilation, layout='NHWC', out_dtype="int32") - s = topi.generic.schedule_conv2d_nhwc([B]) - a = tvm.nd.array(a_np, ctx) - w = tvm.nd.array(w_np, ctx) - b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx) - func = tvm.build(s, [A, W, B], device) - func(a, w, b) - tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5) - - for device in ['llvm -mcpu=skylake-avx512']: - check_device(device) - - -class DefaultFallback(autotvm.FallbackContext): - def _query_inside(self, target, workload): - key = (target, workload) - if key in self.memory: - return self.memory[key] - cfg = FallbackConfigEntity() - cfg.template_key = 'direct' - self.memory[key] = cfg - return cfg - - -def test_conv2d_NCHWc(): - autotvm.DispatchContext.current.silent = True - with DefaultFallback(): - verify_conv2d_NCHWc_int8(1, 256, 32, 256, 7, 7, 0) - - -if __name__ == "__main__": - test_conv2d_NCHWc() diff --git a/topi/tests/python/test_topi_conv2d_nhwc_pack_int8.py b/topi/tests/python/test_topi_conv2d_nhwc_pack_int8.py index 441d9880f6ffe..2171e31d9f899 100644 --- 
a/topi/tests/python/test_topi_conv2d_nhwc_pack_int8.py +++ b/topi/tests/python/test_topi_conv2d_nhwc_pack_int8.py @@ -20,7 +20,7 @@ def verify_conv2d_1x1_nhwc_pack_int8(batch, in_channel, in_size, num_filter, ker w_shape = (kernel, kernel, in_channel, num_filter) dtype = A.dtype - #@memoize("topi.tests.test_topi_conv2d_1x1_nhwc_pack_int8.verify_nhwc.v2") + @memoize("topi.tests.test_topi_conv2d_1x1_nhwc_pack_int8.verify_nhwc.v2") def get_ref_data(): a_np = np.random.uniform(size=a_shape).astype(dtype) w_np = np.random.uniform(size=w_shape).astype(W.dtype) From 26d75a6d7f110da17f818b438930593b098495eb Mon Sep 17 00:00:00 2001 From: lingyiliu Date: Sat, 6 Apr 2019 10:55:30 -0700 Subject: [PATCH 08/17] Add the int8 group conv support on x86 --- topi/python/topi/x86/conv2d.py | 23 ++++- .../test_topi_group_conv2d_NCHWc_int8.py | 93 +++++++++++++++++++ 2 files changed, 113 insertions(+), 3 deletions(-) create mode 100644 topi/tests/python/test_topi_group_conv2d_NCHWc_int8.py diff --git a/topi/python/topi/x86/conv2d.py b/topi/python/topi/x86/conv2d.py index 7333978b5351d..7e8f52cc03f59 100644 --- a/topi/python/topi/x86/conv2d.py +++ b/topi/python/topi/x86/conv2d.py @@ -452,10 +452,11 @@ def _declaration_conv_NCHWc(cfg, data, kernel, strides, n, ic_chunk, ih, iw, ic_bn = get_const_tuple(data.shape) in_channel = ic_chunk * ic_bn if data.dtype == 'uint8': - oc_chunk, _, kernel_height, kernel_width, _, oc_bn, _ = get_const_tuple(kernel.shape) + oc_chunk, ic_chunk_group, kernel_height, kernel_width, _, oc_bn, _ = get_const_tuple(kernel.shape) else: - oc_chunk, _, kernel_height, kernel_width, _, oc_bn = get_const_tuple(kernel.shape) + oc_chunk, ic_chunk_group, kernel_height, kernel_width, _, oc_bn = get_const_tuple(kernel.shape) num_filter = oc_chunk * oc_bn + groups = ic_chunk // ic_chunk_group if cfg.is_fallback: _get_default_config(cfg, tvm.placeholder((n, in_channel, ih, iw), dtype=data.dtype), @@ -479,7 +480,7 @@ def _declaration_conv_NCHWc(cfg, data, kernel, strides, kh = tvm.reduce_axis((0, kernel_height), name='kh') kw = tvm.reduce_axis((0, kernel_width), name='kw') - if data.dtype == 'uint8': + if data.dtype == 'uint8' and groups == 1: assert out_dtype == "int32", \ "INT8 convolution requires input dtype = uint8 and output dtype=int32" # Intel performs dot product of 2 "4" Int8 values @@ -498,6 +499,22 @@ def _declaration_conv_NCHWc(cfg, data, kernel, strides, oc_block, ic_s_inner].astype(out_dtype), axis=[kh, kw, ic_outer, ic_f_inner, ic_s_inner]), name='conv2d_NCHWc_int8', tag="conv2d_NCHWc_int8") + elif data.dtype == 'uint8': + # for int8 group conv support + n_elems = 4 + ic_chunk = in_channel//ic_bn + ic_outer = tvm.reduce_axis((0, ic_chunk//groups), name='ic_outer') + ic_f_inner = tvm.reduce_axis((0, ic_bn//n_elems), name='ic_f_inner') + ic_s_inner = tvm.reduce_axis((0, n_elems), name='ic_s_inner') + oshape = (n, oc_chunk, out_height, out_width, oc_bn) + return tvm.compute(oshape, lambda n, occ, oh, ow, oc_block: + tvm.sum(data_pad[n, (occ*oc_bn//(oc_chunk*oc_bn//groups))*(ic_chunk//groups)+ic_outer, oh*HSTR+kh, ow*WSTR+kw, + ic_f_inner * n_elems + ic_s_inner].astype(out_dtype) * + kernel[occ, ic_outer, kh, kw, ic_f_inner, + oc_block, ic_s_inner].astype(out_dtype), + axis=[kh, kw, ic_outer, ic_f_inner, ic_s_inner]), + name='conv2d_NCHWc_int8', tag="conv2d_NCHWc_int8") + # else: fp implementation return tvm.compute(oshape, lambda n, oc_chunk, oh, ow, oc_block: tvm.sum(data_pad[n, ic//ic_bn, oh*HSTR+kh, ow*WSTR+kw, diff --git a/topi/tests/python/test_topi_group_conv2d_NCHWc_int8.py 
b/topi/tests/python/test_topi_group_conv2d_NCHWc_int8.py new file mode 100644 index 0000000000000..34a437838cc67 --- /dev/null +++ b/topi/tests/python/test_topi_group_conv2d_NCHWc_int8.py @@ -0,0 +1,93 @@ +"""Test for NCHW[x]c convolution""" + +import numpy as np +import tvm +from tvm import autotvm +import topi +import topi.testing +from tvm.contrib.pickle_memoize import memoize +from topi.util import get_const_tuple + +from common import get_all_backend + +def _transform_data(data, bn): + # NCHW -> NCHW[x]c + batch_size, channel, height, width = data.shape + data = np.reshape(data, (batch_size, channel//bn, bn, height, width)) + data = np.transpose(data, (0, 1, 3, 4, 2)) + return data + +def _transform_kernel(kernel, ic_bn, oc_bn): + # OIHW -> OIHW[x]i[x]o + out_channel, in_channel, kh, kw = kernel.shape + kernel = np.reshape(kernel, (out_channel//oc_bn, oc_bn, in_channel//ic_bn, ic_bn//4, kh, kw, 4)) + kernel = np.transpose(kernel, (0, 2, 4, 5, 3, 1, 6)) + return kernel + +def verify_group_conv2d_NCHWc_int8(batch, in_channel, groups, in_size, num_filter, kernel, stride, + padding, dilation=1, add_bias=False, add_relu=False, dtype="int32"): + assert dilation == 1, "conv2d_NCHWc does not support dilation for now." + print("Workload: (%d, %d, %d, %d, %d, %d, %d, %d)" % + (batch, in_channel, groups, in_size, num_filter, kernel, stride, padding)) + + in_height = in_width = in_size + + # for testing functionality, + # we choose arbitrary block size that can divide the channel, + # regardless of the performance. + oc_block = 1 + for bn in range(16, 0, -1): + if num_filter % bn == 0: + oc_block = bn + break + + ic_block = 8 + autotvm.DispatchContext.current.silent = True + A = tvm.placeholder((batch, in_channel//ic_block, in_height, in_width, ic_block), name='A', dtype='uint8') + W = tvm.placeholder((num_filter//oc_block, in_channel//ic_block//groups, kernel, kernel, ic_block//4, oc_block, 4), name='W', dtype='int8') + + @memoize("topi.tests.test_topi_conv2d_NCHWc_int8.verify_conv2d_NCHWc_int8") + def get_ref_data(): + a_np = np.random.uniform(size=(batch, in_channel, in_height, in_width)).astype("uint8") + w_np = np.random.uniform(size=(num_filter, in_channel//groups, kernel, kernel)).astype("int8") + c_np = topi.testing.conv2d_nchw_python(a_np, w_np, stride, padding, groups) + return _transform_data(a_np, ic_block), _transform_kernel(w_np, ic_block, oc_block), \ + _transform_data(c_np, oc_block) + + a_np, w_np, c_np = get_ref_data() + + def check_device(device): + ctx = tvm.context(device, 0) + if not ctx.exist: + print("Skip because %s is not enabled" % device) + return + print("Running on target: %s" % device) + with tvm.target.create(device): + C = topi.nn.conv2d_NCHWc(A, W, (stride, stride), (padding, padding), + (dilation, dilation), + layout='NCHW%dc'%ic_block, + out_layout="NCHW%dc"%oc_block, + out_dtype=dtype) + s = topi.generic.schedule_conv2d_NCHWc([C]) + + a = tvm.nd.array(a_np, ctx) + w = tvm.nd.array(w_np, ctx) + c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), ctx) + func = tvm.build(s, [A, W, C], device, + name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % + (batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation)) + # print(tvm.lower(s, [A, W, C], simple_mode=True)) + func(a, w, c) + tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-3) + + for device in ["llvm -mcpu=skylake-avx512"]: + with autotvm.tophub.context(device): # load tophub pre-tuned parameters + check_device(device) + + +def test_conv2d_NCHWc(): + # ResNet50 workloads + 
verify_group_conv2d_NCHWc_int8(1, 256, 32, 224, 64, 7, 2, 3) + +if __name__ == "__main__": + test_conv2d_NCHWc() From 03e4b38ebd738c31576589d772e8aa0fda3d1e4a Mon Sep 17 00:00:00 2001 From: lingyiliu Date: Fri, 5 Apr 2019 13:54:17 -0700 Subject: [PATCH 09/17] Support the 1x1 int8 conv with NHWC layout and weight packing --- topi/python/topi/nn/conv2d.py | 24 +++- topi/python/topi/x86/conv2d.py | 50 +++++++-- topi/python/topi/x86/conv2d_avx_1x1.py | 106 +++++++++++++++++- .../python/test_topi_conv2d_nhwc_pack_int8.py | 71 ++++++++++++ 4 files changed, 237 insertions(+), 14 deletions(-) create mode 100644 topi/tests/python/test_topi_conv2d_nhwc_pack_int8.py diff --git a/topi/python/topi/nn/conv2d.py b/topi/python/topi/nn/conv2d.py index 49c0bd79eacc5..6b739ee030577 100644 --- a/topi/python/topi/nn/conv2d.py +++ b/topi/python/topi/nn/conv2d.py @@ -28,7 +28,7 @@ # workload description of conv2d Workload = namedtuple('Workload', - ['in_dtype', 'out_dtype', 'height', 'width', 'in_filter', 'out_filter', + ['in_dtype', 'out_dtype', 'height', 'width', 'in_filter', 'groups', 'out_filter', 'hkernel', 'wkernel', 'hpad', 'wpad', 'hstride', 'wstride']) @tvm.target.generic_func @@ -95,11 +95,25 @@ def conv2d_alter_layout(attrs, inputs, tinfos, F): return None -def _get_workload(data, kernel, stride, padding, out_dtype): +def _get_workload(data, kernel, stride, padding, out_dtype, data_layout='NCHW'): """ Get the workload structure. """ - _, CI, IH, IW = [x.value for x in data.shape] - CO, _, KH, KW = [x.value for x in kernel.shape] + if data_layout == 'NCHW': + _, CI, IH, IW = [x.value for x in data.shape] + elif data_layout == 'NHWC': + _, IH, IW, CI = [x.value for x in data.shape] + elif data_layout == 'HWCN': + IH, IW, CI, _ = [x.value for x in data.shape] + else: + raise ValueError("not support this layout {} yet".format(data_layout)) + + + if data_layout == 'NHWC': + KH, KW, CO, CIG = [x.value for x in kernel.shape] + else: + CO, CIG, KH, KW = [x.value for x in kernel.shape] + HPAD, WPAD, _, _ = get_pad_tuple(padding, kernel) + GRPS = CI // CIG if isinstance(stride, (tuple, list)): HSTR, WSTR = stride else: @@ -107,7 +121,7 @@ def _get_workload(data, kernel, stride, padding, out_dtype): assert (data.dtype == kernel.dtype) or (data.dtype == 'uint8' and kernel.dtype == 'int8'), \ "Do not support inputs with different data types now. ' \ '{} vs. 
{}".format(data.dtype, kernel.dtype) - return Workload(data.dtype, out_dtype, IH, IW, CI, CO, KH, KW, HPAD, WPAD, HSTR, WSTR) + return Workload(data.dtype, out_dtype, IH, IW, CI, GRPS, CO, KH, KW, HPAD, WPAD, HSTR, WSTR) def conv2d_nchw(Input, Filter, stride, padding, dilation, out_dtype=None): diff --git a/topi/python/topi/x86/conv2d.py b/topi/python/topi/x86/conv2d.py index f46c948bdeb10..bc2a90bd92749 100644 --- a/topi/python/topi/x86/conv2d.py +++ b/topi/python/topi/x86/conv2d.py @@ -37,7 +37,7 @@ logger = logging.getLogger('topi') -def _get_default_config(cfg, data, kernel, strides, padding, out_dtype, is_depthwise=False): +def _get_default_config(cfg, data, kernel, strides, padding, out_dtype, is_depthwise=False, layout='NCHW'): """ Get default schedule config for the workload """ @@ -46,7 +46,7 @@ def _get_default_config(cfg, data, kernel, strides, padding, out_dtype, is_depth from .depthwise_conv2d import _fallback_schedule _fallback_schedule(cfg, wkl) else: - wkl = _get_conv2d_workload(data, kernel, strides, padding, out_dtype) + wkl = _get_conv2d_workload(data, kernel, strides, padding, out_dtype, layout) is_kernel_1x1 = wkl.hkernel == 1 and wkl.wkernel == 1 if is_kernel_1x1: conv2d_avx_1x1._fallback_schedule(cfg, wkl) @@ -62,6 +62,9 @@ def _create_tuning_space(cfg, data, kernel, strides, padding, dilation, layout): if layout == 'NCHW': n, ic, h, w = dshape oc, _, kh, kw = kshape + elif layout == 'NHWC': + n, h, w, ic = dshape + oc, _, kh, kw = kshape elif pat.match(layout) is not None: n, ic_chunk, h, w, ic_bn = dshape if data.dtype == 'uint8': @@ -93,12 +96,14 @@ def _create_tuning_space(cfg, data, kernel, strides, padding, dilation, layout): cfg.define_knob("unroll_kw", [True, False]) -@autotvm.register_topi_compute(conv2d, 'cpu', 'direct') +@autotvm.register_topi_compute(conv2d, 'cpu', ['direct']) def _declaration_conv(cfg, data, kernel, strides, padding, dilation, layout, out_dtype): out_dtype = data.dtype if out_dtype is None else out_dtype padding = padding if isinstance(padding, (tuple, list)) else (padding, padding) strides = strides if isinstance(strides, (tuple, list)) else (strides, strides) dilation = dilation if isinstance(dilation, (tuple, list)) else (dilation, dilation) + + _, _, kh, kw = get_const_tuple(kernel.shape) if layout == 'NCHW': _create_tuning_space(cfg, data, kernel, strides, padding, dilation, layout) if cfg.is_fallback: @@ -107,7 +112,13 @@ def _declaration_conv(cfg, data, kernel, strides, padding, dilation, layout, out padding, dilation, layout, out_dtype) if layout == 'HWCN': return nn.conv2d_hwcn(data, kernel, strides, padding, dilation, out_dtype) - if layout == 'NHWC': + elif layout == 'NHWC' and kh == 1 and kw == 1 and kernel.dtype == "int8": + if cfg.is_fallback: + _get_default_config(cfg, data, kernel, strides, padding, out_dtype, False, layout) + # specialize for INT8 1X1 conv on X86 + return conv2d_avx_1x1._declaration_conv_nhwc_pack(cfg, data, kernel, strides, + padding, dilation, out_dtype) + elif layout == 'NHWC': return nn.conv2d_nhwc(data, kernel, strides, padding, dilation, out_dtype) raise ValueError("not support this layout {} yet".format(layout)) @@ -226,8 +237,9 @@ def traverse(op): return s -@generic.schedule_conv2d_nhwc.register("cpu") -def schedule_conv2d_nhwc(outs): +# @generic.schedule_conv2d_nhwc.register("cpu") +@autotvm.register_topi_schedule(generic.schedule_conv2d_nhwc, 'cpu', ['direct']) +def schedule_conv2d_nhwc(cfg, outs): """Create schedule for tensors""" s = tvm.create_schedule([x.op for x in outs]) output_op = 
outs[0].op @@ -249,7 +261,31 @@ def traverse(op): if tensor.op.input_tensors and tensor.op not in scheduled_ops: traverse(tensor.op) - if 'conv2d_nhwc' in op.tag: + if 'conv2d_nhwc_pack_int8' in op.tag: + conv_out = op.output(0) + kernel = conv_out.op.input_tensors[1] + data_vec = conv_out.op.input_tensors[0] + data = data_vec.op.input_tensors[0] \ + if isinstance(data_vec.op, tvm.tensor.ComputeOp) and "pad" not in data_vec.op.tag \ + else data_vec + if isinstance(data.op, tvm.tensor.ComputeOp) and "pad" in data.op.tag: + data_pad = data + data = data_pad.op.input_tensors[0] + + args = [s, cfg, data_vec, conv_out, outs[0]] + if data.dtype == 'uint8': + # int8 conv kernel is 7-dim + kh, kw, _, _, _ = get_const_tuple(kernel.shape) + if kh == 1 and kw == 1: + conv2d_avx_1x1._schedule_conv_nhwc_pack_int8(*args) + else: + raise ValueError("Only support 1x1 kernel with " + "schedule template.") + else: + raise ValueError("Not support this data type {} with " + "schedule template.".format(data.dtype)) + + elif 'conv2d_nhwc' in op.tag: conv = op.output(0) kernel = op.input_tensors[1] if isinstance(kernel.op, tvm.tensor.ComputeOp) and "dilate" in kernel.op.tag: diff --git a/topi/python/topi/x86/conv2d_avx_1x1.py b/topi/python/topi/x86/conv2d_avx_1x1.py index bcd2cefc2bdf2..29bb0802e3c59 100644 --- a/topi/python/topi/x86/conv2d_avx_1x1.py +++ b/topi/python/topi/x86/conv2d_avx_1x1.py @@ -20,8 +20,9 @@ import tvm from tvm.autotvm.task.space import SplitEntity, OtherOptionEntity -from ..nn.util import infer_pad -from ..util import get_const_tuple +from ..nn.pad import pad +from ..nn.util import infer_pad, get_pad_tuple +from ..util import get_const_tuple, simplify from .tensor_intrin import dot_16x1x16_int8_int8_int32 from .check_targets import check_skylake from .util import get_fp32_len @@ -251,3 +252,104 @@ def _schedule_conv_NCHWc_int8(s, cfg, data, conv_out, last): s[O].parallel(parallel_axis) return s + + +def _declaration_conv_nhwc_pack(cfg, Input, Filter, stride, padding, dilation, out_dtype): + # more assertion for the shapes + assert isinstance(stride, int) or len(stride) == 2 + assert isinstance(dilation, int) or len(dilation) == 2 + if isinstance(stride, int): + stride_h = stride_w = stride + else: + stride_h, stride_w = stride + + if isinstance(dilation, int): + dilation_h = dilation_w = dilation + else: + dilation_h, dilation_w = dilation + + batch, in_height, in_width, in_channel = Input.shape + kernel_h, kernel_w, num_filter, channel = Filter.shape + + # compute the output shape + dilated_kernel_h = (kernel_h - 1) * dilation_h + 1 + dilated_kernel_w = (kernel_w - 1) * dilation_w + 1 + pad_top, pad_left, pad_down, pad_right = get_pad_tuple( + padding, (dilated_kernel_h, dilated_kernel_w)) + out_channel = num_filter + out_height = simplify((in_height - dilated_kernel_h + pad_top + pad_down) // stride_h + 1) + out_width = simplify((in_width - dilated_kernel_w + pad_left + pad_right) // stride_w + 1) + pad_before = [0, pad_top, pad_left, 0] + pad_after = [0, pad_down, pad_right, 0] + PaddedInput = pad(Input, pad_before, pad_after, name="PaddedInput") + # todo: padding filter to accomodate the intrinsic + + # packing the Filter to let memory access be consecutive for AVX512 intrinsic + # Done in pre-compute stage + packw_shape = (kernel_h, kernel_w, num_filter/16, 16*(channel/4), 4) + PackW = tvm.compute(packw_shape, lambda a, b, c, d, e: Filter[a][b][c*16+d%16][d/16*4+e], name="packed_filter") + + rc = tvm.reduce_axis((0, in_channel), name='rc') + ry = tvm.reduce_axis((0, kernel_h), 
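# Worked example of the output-size arithmetic above (illustrative only; conv_out_dim
# is not part of the patch):
def conv_out_dim(in_dim, kernel, stride, pad_before, pad_after, dilation=1):
    dilated = (kernel - 1) * dilation + 1
    return (in_dim - dilated + pad_before + pad_after) // stride + 1

assert conv_out_dim(32, 1, 1, 0, 0) == 32     # the 1x1, stride-1 case this path targets
assert conv_out_dim(224, 7, 2, 3, 3) == 112   # a ResNet-style 7x7, stride-2 layer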
name='ry') + rx = tvm.reduce_axis((0, kernel_w), name='rx') + Output = tvm.compute( + (batch, out_height, out_width, out_channel), + lambda nn, yy, xx, ff: tvm.sum( + PaddedInput[nn, yy * stride_h + ry * dilation_h, + xx * stride_w + rx * dilation_w, rc].astype(out_dtype) * + PackW[ry, rx, ff/16, (rc/4)*16+ff%16, rc%4].astype(out_dtype), axis=[ry, rx, rc]), + name="Conv2d_1x1_Output_int8", tag="conv2d_nhwc_pack_int8") + return Output + + +def _schedule_conv_nhwc_pack_int8(s, cfg, data, conv_out, last): + """ + Defines the schedule for the int8 nhwc layout. For 1x1 conv, it + is a matrix-multiply operation by using nhwc layout. We will do + packing of weight to make the address access be friendly to int8 + intrinsic + """ + target = tvm.target.current_target(allow_none=False) + int32_lanes = -1 + if check_skylake(target): + int32_lanes = 16 + else: + return s + assert int32_lanes != -1 + + # assertion to fail the unhandled case + _, _, _, ic_num = get_const_tuple(data.shape) + _, _, _, oc_num = get_const_tuple(conv_out.shape) + assert ic_num % 4 == 0 + assert oc_num % 16 == 0 + + ic_factor, oc_factor = cfg["tile_ic"].size[-1], cfg["tile_oc"].size[-1] + # schedule data + A = data + if isinstance(s[A].op, tvm.tensor.ComputeOp): + batch, ih, iw, ic = s[A].op.axis + d_ic_chunk, d_ic_block = s[A].split(ic, factor=4) + s[A].vectorize(d_ic_block) + + C, O = conv_out, last + + batch, oh, ow, oc = s[C].op.axis + kh, kw, ic = s[C].op.reduce_axis + # match the x86 intrinsic + ic_outer, ic_inner = s[C].split(ic, factor=4) + oc_outer, oc_inner = s[C].split(oc, factor=int32_lanes) + + ic_f_outer, ic_s_outer = s[C].split(ic_outer, factor=ic_factor) + s[C].reorder(oc_outer, oh, ow, ic_f_outer, ic_s_outer, kh, kw, oc_inner, ic_inner) + + pc = dot_16x1x16_int8_int8_int32() + s[C].tensorize(oc_inner, pc) + + if C != O: + batch, last_oh, last_ow, last_oc = s[O].op.axis + oc_chunk, oc_block = s[O].split(ochannel, 16) + # not saw perf improvement to split oh/ow here + s[O].vectorize(oc_block) + + return s + diff --git a/topi/tests/python/test_topi_conv2d_nhwc_pack_int8.py b/topi/tests/python/test_topi_conv2d_nhwc_pack_int8.py new file mode 100644 index 0000000000000..441d9880f6ffe --- /dev/null +++ b/topi/tests/python/test_topi_conv2d_nhwc_pack_int8.py @@ -0,0 +1,71 @@ +"""Example code to do convolution.""" +import os +import numpy as np +import tvm +from tvm import autotvm +from tvm.autotvm.task.space import FallbackConfigEntity +import topi +import topi.testing +from tvm.contrib.pickle_memoize import memoize +from topi.util import get_const_tuple + + +def verify_conv2d_1x1_nhwc_pack_int8(batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation=1): + in_height = in_width = in_size + + A = tvm.placeholder((batch, in_height, in_width, in_channel), name='A', dtype='uint8') + W = tvm.placeholder((kernel, kernel, in_channel, num_filter), name='W', dtype='int8') + + a_shape = get_const_tuple(A.shape) + w_shape = (kernel, kernel, in_channel, num_filter) + dtype = A.dtype + + #@memoize("topi.tests.test_topi_conv2d_1x1_nhwc_pack_int8.verify_nhwc.v2") + def get_ref_data(): + a_np = np.random.uniform(size=a_shape).astype(dtype) + w_np = np.random.uniform(size=w_shape).astype(W.dtype) + dw_np = topi.testing.dilate_python(w_np, (dilation, dilation, 1, 1)) + b_np = topi.testing.conv2d_nhwc_python(a_np, dw_np, stride, padding) + return a_np, w_np, b_np + a_np, w_np, b_np = get_ref_data() + + def check_device(device): + ctx = tvm.context(device, 0) + if not ctx.exist: + print("Skip because %s is not enabled" % 
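# What one application of the dot_16x1x16_int8_int8_int32 tensorization above computes,
# emulated in NumPy for a single (oc_inner, ic_inner) tile: 16 int32 lanes, each
# accumulating a 4-wide uint8 x int8 dot product. Illustrative sketch only; on a
# Skylake-AVX512 target the intrinsic lowers to packed multiply-add instructions instead.
import numpy as np

def dot_16x1x16_ref(acc_i32, data_u8, kernel_i8):
    # acc_i32: (16,) int32, data_u8: (4,) uint8, kernel_i8: (16, 4) int8
    return acc_i32 + (kernel_i8.astype("int32") * data_u8.astype("int32")).sum(axis=1)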
device) + return + print("Running on target: %s" % device) + + with tvm.target.create(device): + B = topi.nn.conv2d(A, W, stride, padding, dilation, layout='NHWC', out_dtype="int32") + s = topi.generic.schedule_conv2d_nhwc([B]) + a = tvm.nd.array(a_np, ctx) + w = tvm.nd.array(w_np, ctx) + b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx) + func = tvm.build(s, [A, W, B], device) + func(a, w, b) + tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5) + + for device in ['llvm -mcpu=skylake-avx512']: + check_device(device) + + +class DefaultFallback(autotvm.FallbackContext): + def _query_inside(self, target, workload): + key = (target, workload) + if key in self.memory: + return self.memory[key] + cfg = FallbackConfigEntity() + cfg.template_key = 'direct' + self.memory[key] = cfg + return cfg + + +def test_conv2d_nhwc(): + autotvm.DispatchContext.current.silent = True + with DefaultFallback(): + verify_conv2d_1x1_nhwc_pack_int8(1, 256, 32, 256, 1, 1, 0) + + +if __name__ == "__main__": + test_conv2d_nhwc() From 0b05a63c05a374ce076e257d52a8b4656b5990e7 Mon Sep 17 00:00:00 2001 From: lingyiliu Date: Fri, 5 Apr 2019 15:09:00 -0700 Subject: [PATCH 10/17] memoize the test result --- .../python/test_topi_conv2d_NCHWc_int8.py | 71 +++++++++++++++++++ 1 file changed, 71 insertions(+) create mode 100644 topi/tests/python/test_topi_conv2d_NCHWc_int8.py diff --git a/topi/tests/python/test_topi_conv2d_NCHWc_int8.py b/topi/tests/python/test_topi_conv2d_NCHWc_int8.py new file mode 100644 index 0000000000000..2cf32dbc040e2 --- /dev/null +++ b/topi/tests/python/test_topi_conv2d_NCHWc_int8.py @@ -0,0 +1,71 @@ +"""Example code to do convolution.""" +import os +import numpy as np +import tvm +from tvm import autotvm +from tvm.autotvm.task.space import FallbackConfigEntity +import topi +import topi.testing +from tvm.contrib.pickle_memoize import memoize +from topi.util import get_const_tuple + + +def verify_conv2d_NCHWc_int8(batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation=1): + in_height = in_width = in_size + + A = tvm.placeholder((batch, in_height, in_width, in_channel), name='A', dtype='uint8') + W = tvm.placeholder((kernel, kernel, in_channel, num_filter, kernel, kernel), name='W', dtype='int8') + + a_shape = get_const_tuple(A.shape) + w_shape = (kernel, kernel, in_channel, num_filter) + dtype = A.dtype + + #@memoize("topi.tests.test_topi_conv2d_1x1_nhwc_pack_int8.verify_nhwc.v2") + def get_ref_data(): + a_np = np.random.uniform(size=a_shape).astype(dtype) + w_np = np.random.uniform(size=w_shape).astype(W.dtype) + dw_np = topi.testing.dilate_python(w_np, (dilation, dilation, 1, 1)) + b_np = topi.testing.conv2d_nchw_python(a_np, dw_np, stride, padding) + return a_np, w_np, b_np + a_np, w_np, b_np = get_ref_data() + + def check_device(device): + ctx = tvm.context(device, 0) + if not ctx.exist: + print("Skip because %s is not enabled" % device) + return + print("Running on target: %s" % device) + + with tvm.target.create(device): + B = topi.nn.conv2d(A, W, stride, padding, dilation, layout='NHWC', out_dtype="int32") + s = topi.generic.schedule_conv2d_nhwc([B]) + a = tvm.nd.array(a_np, ctx) + w = tvm.nd.array(w_np, ctx) + b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx) + func = tvm.build(s, [A, W, B], device) + func(a, w, b) + tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5) + + for device in ['llvm -mcpu=skylake-avx512']: + check_device(device) + + +class DefaultFallback(autotvm.FallbackContext): + def _query_inside(self, 
target, workload): + key = (target, workload) + if key in self.memory: + return self.memory[key] + cfg = FallbackConfigEntity() + cfg.template_key = 'direct' + self.memory[key] = cfg + return cfg + + +def test_conv2d_NCHWc(): + autotvm.DispatchContext.current.silent = True + with DefaultFallback(): + verify_conv2d_NCHWc_int8(1, 256, 32, 256, 7, 7, 0) + + +if __name__ == "__main__": + test_conv2d_NCHWc() From b53b61abb0aa890659428d12bb4f02eea9ff2817 Mon Sep 17 00:00:00 2001 From: lingyiliu Date: Fri, 5 Apr 2019 15:14:49 -0700 Subject: [PATCH 11/17] fix the wrong file --- .../python/test_topi_conv2d_NCHWc_int8.py | 71 ------------------- .../python/test_topi_conv2d_nhwc_pack_int8.py | 2 +- 2 files changed, 1 insertion(+), 72 deletions(-) delete mode 100644 topi/tests/python/test_topi_conv2d_NCHWc_int8.py diff --git a/topi/tests/python/test_topi_conv2d_NCHWc_int8.py b/topi/tests/python/test_topi_conv2d_NCHWc_int8.py deleted file mode 100644 index 2cf32dbc040e2..0000000000000 --- a/topi/tests/python/test_topi_conv2d_NCHWc_int8.py +++ /dev/null @@ -1,71 +0,0 @@ -"""Example code to do convolution.""" -import os -import numpy as np -import tvm -from tvm import autotvm -from tvm.autotvm.task.space import FallbackConfigEntity -import topi -import topi.testing -from tvm.contrib.pickle_memoize import memoize -from topi.util import get_const_tuple - - -def verify_conv2d_NCHWc_int8(batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation=1): - in_height = in_width = in_size - - A = tvm.placeholder((batch, in_height, in_width, in_channel), name='A', dtype='uint8') - W = tvm.placeholder((kernel, kernel, in_channel, num_filter, kernel, kernel), name='W', dtype='int8') - - a_shape = get_const_tuple(A.shape) - w_shape = (kernel, kernel, in_channel, num_filter) - dtype = A.dtype - - #@memoize("topi.tests.test_topi_conv2d_1x1_nhwc_pack_int8.verify_nhwc.v2") - def get_ref_data(): - a_np = np.random.uniform(size=a_shape).astype(dtype) - w_np = np.random.uniform(size=w_shape).astype(W.dtype) - dw_np = topi.testing.dilate_python(w_np, (dilation, dilation, 1, 1)) - b_np = topi.testing.conv2d_nchw_python(a_np, dw_np, stride, padding) - return a_np, w_np, b_np - a_np, w_np, b_np = get_ref_data() - - def check_device(device): - ctx = tvm.context(device, 0) - if not ctx.exist: - print("Skip because %s is not enabled" % device) - return - print("Running on target: %s" % device) - - with tvm.target.create(device): - B = topi.nn.conv2d(A, W, stride, padding, dilation, layout='NHWC', out_dtype="int32") - s = topi.generic.schedule_conv2d_nhwc([B]) - a = tvm.nd.array(a_np, ctx) - w = tvm.nd.array(w_np, ctx) - b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx) - func = tvm.build(s, [A, W, B], device) - func(a, w, b) - tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5) - - for device in ['llvm -mcpu=skylake-avx512']: - check_device(device) - - -class DefaultFallback(autotvm.FallbackContext): - def _query_inside(self, target, workload): - key = (target, workload) - if key in self.memory: - return self.memory[key] - cfg = FallbackConfigEntity() - cfg.template_key = 'direct' - self.memory[key] = cfg - return cfg - - -def test_conv2d_NCHWc(): - autotvm.DispatchContext.current.silent = True - with DefaultFallback(): - verify_conv2d_NCHWc_int8(1, 256, 32, 256, 7, 7, 0) - - -if __name__ == "__main__": - test_conv2d_NCHWc() diff --git a/topi/tests/python/test_topi_conv2d_nhwc_pack_int8.py b/topi/tests/python/test_topi_conv2d_nhwc_pack_int8.py index 441d9880f6ffe..2171e31d9f899 100644 --- 
a/topi/tests/python/test_topi_conv2d_nhwc_pack_int8.py +++ b/topi/tests/python/test_topi_conv2d_nhwc_pack_int8.py @@ -20,7 +20,7 @@ def verify_conv2d_1x1_nhwc_pack_int8(batch, in_channel, in_size, num_filter, ker w_shape = (kernel, kernel, in_channel, num_filter) dtype = A.dtype - #@memoize("topi.tests.test_topi_conv2d_1x1_nhwc_pack_int8.verify_nhwc.v2") + @memoize("topi.tests.test_topi_conv2d_1x1_nhwc_pack_int8.verify_nhwc.v2") def get_ref_data(): a_np = np.random.uniform(size=a_shape).astype(dtype) w_np = np.random.uniform(size=w_shape).astype(W.dtype) From d5ef1183609c94e04389514c4b954d1c9641ddc2 Mon Sep 17 00:00:00 2001 From: lingyiliu Date: Sat, 6 Apr 2019 10:55:30 -0700 Subject: [PATCH 12/17] Add the int8 group conv support on x86 --- topi/python/topi/x86/conv2d.py | 23 ++++- .../test_topi_group_conv2d_NCHWc_int8.py | 93 +++++++++++++++++++ 2 files changed, 113 insertions(+), 3 deletions(-) create mode 100644 topi/tests/python/test_topi_group_conv2d_NCHWc_int8.py diff --git a/topi/python/topi/x86/conv2d.py b/topi/python/topi/x86/conv2d.py index bc2a90bd92749..46ef7ed4d5195 100644 --- a/topi/python/topi/x86/conv2d.py +++ b/topi/python/topi/x86/conv2d.py @@ -453,10 +453,11 @@ def _declaration_conv_NCHWc(cfg, data, kernel, strides, n, ic_chunk, ih, iw, ic_bn = get_const_tuple(data.shape) in_channel = ic_chunk * ic_bn if data.dtype == 'uint8': - oc_chunk, _, kernel_height, kernel_width, _, oc_bn, _ = get_const_tuple(kernel.shape) + oc_chunk, ic_chunk_group, kernel_height, kernel_width, _, oc_bn, _ = get_const_tuple(kernel.shape) else: - oc_chunk, _, kernel_height, kernel_width, _, oc_bn = get_const_tuple(kernel.shape) + oc_chunk, ic_chunk_group, kernel_height, kernel_width, _, oc_bn = get_const_tuple(kernel.shape) num_filter = oc_chunk * oc_bn + groups = ic_chunk // ic_chunk_group if cfg.is_fallback: _get_default_config(cfg, tvm.placeholder((n, in_channel, ih, iw), dtype=data.dtype), @@ -480,7 +481,7 @@ def _declaration_conv_NCHWc(cfg, data, kernel, strides, kh = tvm.reduce_axis((0, kernel_height), name='kh') kw = tvm.reduce_axis((0, kernel_width), name='kw') - if data.dtype == 'uint8': + if data.dtype == 'uint8' and groups == 1: assert out_dtype == "int32", \ "INT8 convolution requires input dtype = uint8 and output dtype=int32" # Intel performs dot product of 2 "4" Int8 values @@ -499,6 +500,22 @@ def _declaration_conv_NCHWc(cfg, data, kernel, strides, oc_block, ic_s_inner].astype(out_dtype), axis=[kh, kw, ic_outer, ic_f_inner, ic_s_inner]), name='conv2d_NCHWc_int8', tag="conv2d_NCHWc_int8") + elif data.dtype == 'uint8': + # for int8 group conv support + n_elems = 4 + ic_chunk = in_channel//ic_bn + ic_outer = tvm.reduce_axis((0, ic_chunk//groups), name='ic_outer') + ic_f_inner = tvm.reduce_axis((0, ic_bn//n_elems), name='ic_f_inner') + ic_s_inner = tvm.reduce_axis((0, n_elems), name='ic_s_inner') + oshape = (n, oc_chunk, out_height, out_width, oc_bn) + return tvm.compute(oshape, lambda n, occ, oh, ow, oc_block: + tvm.sum(data_pad[n, (occ*oc_bn//(oc_chunk*oc_bn//groups))*(ic_chunk//groups)+ic_outer, oh*HSTR+kh, ow*WSTR+kw, + ic_f_inner * n_elems + ic_s_inner].astype(out_dtype) * + kernel[occ, ic_outer, kh, kw, ic_f_inner, + oc_block, ic_s_inner].astype(out_dtype), + axis=[kh, kw, ic_outer, ic_f_inner, ic_s_inner]), + name='conv2d_NCHWc_int8', tag="conv2d_NCHWc_int8") + # else: fp implementation return tvm.compute(oshape, lambda n, oc_chunk, oh, ow, oc_block: tvm.sum(data_pad[n, ic//ic_bn, oh*HSTR+kh, ow*WSTR+kw, diff --git a/topi/tests/python/test_topi_group_conv2d_NCHWc_int8.py 
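# The channel indexing used by the int8 group-conv compute above, restated in plain
# Python (illustrative; assumes oc_bn divides the per-group output-channel count):
def input_chunk_index(occ, ic_outer, oc_bn, oc_chunk, ic_chunk, groups):
    out_channel = oc_chunk * oc_bn
    group_id = occ * oc_bn // (out_channel // groups)   # group owning this output block
    return group_id * (ic_chunk // groups) + ic_outer   # input-channel chunk to read

# e.g. groups=4, oc_chunk=8, oc_bn=16, ic_chunk=16: output block occ=3 belongs to
# group 1, so it reads input chunks 4..7 as ic_outer sweeps ic_chunk//groups.
assert input_chunk_index(3, 0, 16, 8, 16, 4) == 4
assert input_chunk_index(3, 3, 16, 8, 16, 4) == 7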
b/topi/tests/python/test_topi_group_conv2d_NCHWc_int8.py new file mode 100644 index 0000000000000..34a437838cc67 --- /dev/null +++ b/topi/tests/python/test_topi_group_conv2d_NCHWc_int8.py @@ -0,0 +1,93 @@ +"""Test for NCHW[x]c convolution""" + +import numpy as np +import tvm +from tvm import autotvm +import topi +import topi.testing +from tvm.contrib.pickle_memoize import memoize +from topi.util import get_const_tuple + +from common import get_all_backend + +def _transform_data(data, bn): + # NCHW -> NCHW[x]c + batch_size, channel, height, width = data.shape + data = np.reshape(data, (batch_size, channel//bn, bn, height, width)) + data = np.transpose(data, (0, 1, 3, 4, 2)) + return data + +def _transform_kernel(kernel, ic_bn, oc_bn): + # OIHW -> OIHW[x]i[x]o + out_channel, in_channel, kh, kw = kernel.shape + kernel = np.reshape(kernel, (out_channel//oc_bn, oc_bn, in_channel//ic_bn, ic_bn//4, kh, kw, 4)) + kernel = np.transpose(kernel, (0, 2, 4, 5, 3, 1, 6)) + return kernel + +def verify_group_conv2d_NCHWc_int8(batch, in_channel, groups, in_size, num_filter, kernel, stride, + padding, dilation=1, add_bias=False, add_relu=False, dtype="int32"): + assert dilation == 1, "conv2d_NCHWc does not support dilation for now." + print("Workload: (%d, %d, %d, %d, %d, %d, %d, %d)" % + (batch, in_channel, groups, in_size, num_filter, kernel, stride, padding)) + + in_height = in_width = in_size + + # for testing functionality, + # we choose arbitrary block size that can divide the channel, + # regardless of the performance. + oc_block = 1 + for bn in range(16, 0, -1): + if num_filter % bn == 0: + oc_block = bn + break + + ic_block = 8 + autotvm.DispatchContext.current.silent = True + A = tvm.placeholder((batch, in_channel//ic_block, in_height, in_width, ic_block), name='A', dtype='uint8') + W = tvm.placeholder((num_filter//oc_block, in_channel//ic_block//groups, kernel, kernel, ic_block//4, oc_block, 4), name='W', dtype='int8') + + @memoize("topi.tests.test_topi_conv2d_NCHWc_int8.verify_conv2d_NCHWc_int8") + def get_ref_data(): + a_np = np.random.uniform(size=(batch, in_channel, in_height, in_width)).astype("uint8") + w_np = np.random.uniform(size=(num_filter, in_channel//groups, kernel, kernel)).astype("int8") + c_np = topi.testing.conv2d_nchw_python(a_np, w_np, stride, padding, groups) + return _transform_data(a_np, ic_block), _transform_kernel(w_np, ic_block, oc_block), \ + _transform_data(c_np, oc_block) + + a_np, w_np, c_np = get_ref_data() + + def check_device(device): + ctx = tvm.context(device, 0) + if not ctx.exist: + print("Skip because %s is not enabled" % device) + return + print("Running on target: %s" % device) + with tvm.target.create(device): + C = topi.nn.conv2d_NCHWc(A, W, (stride, stride), (padding, padding), + (dilation, dilation), + layout='NCHW%dc'%ic_block, + out_layout="NCHW%dc"%oc_block, + out_dtype=dtype) + s = topi.generic.schedule_conv2d_NCHWc([C]) + + a = tvm.nd.array(a_np, ctx) + w = tvm.nd.array(w_np, ctx) + c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), ctx) + func = tvm.build(s, [A, W, C], device, + name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % + (batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation)) + # print(tvm.lower(s, [A, W, C], simple_mode=True)) + func(a, w, c) + tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-3) + + for device in ["llvm -mcpu=skylake-avx512"]: + with autotvm.tophub.context(device): # load tophub pre-tuned parameters + check_device(device) + + +def test_conv2d_NCHWc(): + # ResNet50 workloads + 
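# Quick shape check for the transforms above, using the workload exercised by this test
# (in_channel=256, groups=32, ic_block=8, num_filter=64 -> oc_block=16, kernel=7).
# Illustrative aside only; it reuses numpy as np and the helpers defined earlier in
# this file.
a = np.zeros((1, 256, 224, 224), dtype="uint8")
w = np.zeros((64, 256 // 32, 7, 7), dtype="int8")
assert _transform_data(a, 8).shape == (1, 32, 224, 224, 8)          # NCHW -> NCHW8c
assert _transform_kernel(w, 8, 16).shape == (4, 1, 7, 7, 2, 16, 4)  # OIHW -> blocked int8 layout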
verify_group_conv2d_NCHWc_int8(1, 256, 32, 224, 64, 7, 2, 3) + +if __name__ == "__main__": + test_conv2d_NCHWc() From 42ae2a0fdaaf849808f62306f8ed148567f3289c Mon Sep 17 00:00:00 2001 From: lingyiliu Date: Fri, 5 Apr 2019 13:54:17 -0700 Subject: [PATCH 13/17] Support the 1x1 int8 conv with NHWC layout and weight packing --- topi/python/topi/x86/conv2d.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/topi/python/topi/x86/conv2d.py b/topi/python/topi/x86/conv2d.py index 46ef7ed4d5195..95671646bb4ca 100644 --- a/topi/python/topi/x86/conv2d.py +++ b/topi/python/topi/x86/conv2d.py @@ -64,7 +64,10 @@ def _create_tuning_space(cfg, data, kernel, strides, padding, dilation, layout): oc, _, kh, kw = kshape elif layout == 'NHWC': n, h, w, ic = dshape +<<<<<<< HEAD oc, _, kh, kw = kshape +======= +>>>>>>> Support the 1x1 int8 conv with NHWC layout and weight packing elif pat.match(layout) is not None: n, ic_chunk, h, w, ic_bn = dshape if data.dtype == 'uint8': From 9411ef6ccbf3f7277372d96ca11765f9b63ad209 Mon Sep 17 00:00:00 2001 From: lingyiliu Date: Fri, 5 Apr 2019 15:09:00 -0700 Subject: [PATCH 14/17] memoize the test result --- .../python/test_topi_conv2d_NCHWc_int8.py | 71 +++++++++++++++++++ 1 file changed, 71 insertions(+) create mode 100644 topi/tests/python/test_topi_conv2d_NCHWc_int8.py diff --git a/topi/tests/python/test_topi_conv2d_NCHWc_int8.py b/topi/tests/python/test_topi_conv2d_NCHWc_int8.py new file mode 100644 index 0000000000000..2cf32dbc040e2 --- /dev/null +++ b/topi/tests/python/test_topi_conv2d_NCHWc_int8.py @@ -0,0 +1,71 @@ +"""Example code to do convolution.""" +import os +import numpy as np +import tvm +from tvm import autotvm +from tvm.autotvm.task.space import FallbackConfigEntity +import topi +import topi.testing +from tvm.contrib.pickle_memoize import memoize +from topi.util import get_const_tuple + + +def verify_conv2d_NCHWc_int8(batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation=1): + in_height = in_width = in_size + + A = tvm.placeholder((batch, in_height, in_width, in_channel), name='A', dtype='uint8') + W = tvm.placeholder((kernel, kernel, in_channel, num_filter, kernel, kernel), name='W', dtype='int8') + + a_shape = get_const_tuple(A.shape) + w_shape = (kernel, kernel, in_channel, num_filter) + dtype = A.dtype + + #@memoize("topi.tests.test_topi_conv2d_1x1_nhwc_pack_int8.verify_nhwc.v2") + def get_ref_data(): + a_np = np.random.uniform(size=a_shape).astype(dtype) + w_np = np.random.uniform(size=w_shape).astype(W.dtype) + dw_np = topi.testing.dilate_python(w_np, (dilation, dilation, 1, 1)) + b_np = topi.testing.conv2d_nchw_python(a_np, dw_np, stride, padding) + return a_np, w_np, b_np + a_np, w_np, b_np = get_ref_data() + + def check_device(device): + ctx = tvm.context(device, 0) + if not ctx.exist: + print("Skip because %s is not enabled" % device) + return + print("Running on target: %s" % device) + + with tvm.target.create(device): + B = topi.nn.conv2d(A, W, stride, padding, dilation, layout='NHWC', out_dtype="int32") + s = topi.generic.schedule_conv2d_nhwc([B]) + a = tvm.nd.array(a_np, ctx) + w = tvm.nd.array(w_np, ctx) + b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx) + func = tvm.build(s, [A, W, B], device) + func(a, w, b) + tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5) + + for device in ['llvm -mcpu=skylake-avx512']: + check_device(device) + + +class DefaultFallback(autotvm.FallbackContext): + def _query_inside(self, target, workload): + key = (target, workload) + if key in self.memory: + return 
self.memory[key] + cfg = FallbackConfigEntity() + cfg.template_key = 'direct' + self.memory[key] = cfg + return cfg + + +def test_conv2d_NCHWc(): + autotvm.DispatchContext.current.silent = True + with DefaultFallback(): + verify_conv2d_NCHWc_int8(1, 256, 32, 256, 7, 7, 0) + + +if __name__ == "__main__": + test_conv2d_NCHWc() From da200ddf83451fecb3413063ae89f59cf1679e74 Mon Sep 17 00:00:00 2001 From: lingyiliu Date: Fri, 5 Apr 2019 15:14:49 -0700 Subject: [PATCH 15/17] fix the wrong file --- .../python/test_topi_conv2d_NCHWc_int8.py | 71 ------------------- 1 file changed, 71 deletions(-) delete mode 100644 topi/tests/python/test_topi_conv2d_NCHWc_int8.py diff --git a/topi/tests/python/test_topi_conv2d_NCHWc_int8.py b/topi/tests/python/test_topi_conv2d_NCHWc_int8.py deleted file mode 100644 index 2cf32dbc040e2..0000000000000 --- a/topi/tests/python/test_topi_conv2d_NCHWc_int8.py +++ /dev/null @@ -1,71 +0,0 @@ -"""Example code to do convolution.""" -import os -import numpy as np -import tvm -from tvm import autotvm -from tvm.autotvm.task.space import FallbackConfigEntity -import topi -import topi.testing -from tvm.contrib.pickle_memoize import memoize -from topi.util import get_const_tuple - - -def verify_conv2d_NCHWc_int8(batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation=1): - in_height = in_width = in_size - - A = tvm.placeholder((batch, in_height, in_width, in_channel), name='A', dtype='uint8') - W = tvm.placeholder((kernel, kernel, in_channel, num_filter, kernel, kernel), name='W', dtype='int8') - - a_shape = get_const_tuple(A.shape) - w_shape = (kernel, kernel, in_channel, num_filter) - dtype = A.dtype - - #@memoize("topi.tests.test_topi_conv2d_1x1_nhwc_pack_int8.verify_nhwc.v2") - def get_ref_data(): - a_np = np.random.uniform(size=a_shape).astype(dtype) - w_np = np.random.uniform(size=w_shape).astype(W.dtype) - dw_np = topi.testing.dilate_python(w_np, (dilation, dilation, 1, 1)) - b_np = topi.testing.conv2d_nchw_python(a_np, dw_np, stride, padding) - return a_np, w_np, b_np - a_np, w_np, b_np = get_ref_data() - - def check_device(device): - ctx = tvm.context(device, 0) - if not ctx.exist: - print("Skip because %s is not enabled" % device) - return - print("Running on target: %s" % device) - - with tvm.target.create(device): - B = topi.nn.conv2d(A, W, stride, padding, dilation, layout='NHWC', out_dtype="int32") - s = topi.generic.schedule_conv2d_nhwc([B]) - a = tvm.nd.array(a_np, ctx) - w = tvm.nd.array(w_np, ctx) - b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx) - func = tvm.build(s, [A, W, B], device) - func(a, w, b) - tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5) - - for device in ['llvm -mcpu=skylake-avx512']: - check_device(device) - - -class DefaultFallback(autotvm.FallbackContext): - def _query_inside(self, target, workload): - key = (target, workload) - if key in self.memory: - return self.memory[key] - cfg = FallbackConfigEntity() - cfg.template_key = 'direct' - self.memory[key] = cfg - return cfg - - -def test_conv2d_NCHWc(): - autotvm.DispatchContext.current.silent = True - with DefaultFallback(): - verify_conv2d_NCHWc_int8(1, 256, 32, 256, 7, 7, 0) - - -if __name__ == "__main__": - test_conv2d_NCHWc() From d5190c9c05740b2719d369da2b90f43bf0f853ba Mon Sep 17 00:00:00 2001 From: lingyiliu Date: Mon, 8 Apr 2019 22:10:20 -0700 Subject: [PATCH 16/17] fix merge conflict --- topi/python/topi/x86/conv2d.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/topi/python/topi/x86/conv2d.py b/topi/python/topi/x86/conv2d.py 
index e9e63d2383b84..7775ca6f43b85 100644 --- a/topi/python/topi/x86/conv2d.py +++ b/topi/python/topi/x86/conv2d.py @@ -514,11 +514,7 @@ def _declaration_conv_NCHWc(cfg, data, kernel, strides, kernel[occ, ic_outer, kh, kw, ic_f_inner, oc_block, ic_s_inner].astype(out_dtype), axis=[kh, kw, ic_outer, ic_f_inner, ic_s_inner]), -<<<<<<< HEAD name='conv2d_NCHWc_int8', tag="conv2d_NCHWc_int8") -======= - name='conv2d_NCHWc_int8', tag="conv2d_NCHWc_int8") ->>>>>>> cb0c00dbdb970f0e39b0d66e2ac4f8d82ecf6685 # else: fp implementation return tvm.compute(oshape, lambda n, oc_chunk, oh, ow, oc_block: From 87cbfbfd2c88d08f42015d611571be2ea4afef48 Mon Sep 17 00:00:00 2001 From: lingyiliu Date: Mon, 8 Apr 2019 22:43:33 -0700 Subject: [PATCH 17/17] fix linter --- topi/python/topi/nn/conv2d.py | 4 ++-- topi/python/topi/x86/conv2d.py | 19 ++++++++++++------- topi/python/topi/x86/conv2d_avx_1x1.py | 6 +++--- 3 files changed, 17 insertions(+), 12 deletions(-) diff --git a/topi/python/topi/nn/conv2d.py b/topi/python/topi/nn/conv2d.py index 6b739ee030577..dc8312cb0d200 100644 --- a/topi/python/topi/nn/conv2d.py +++ b/topi/python/topi/nn/conv2d.py @@ -28,8 +28,8 @@ # workload description of conv2d Workload = namedtuple('Workload', - ['in_dtype', 'out_dtype', 'height', 'width', 'in_filter', 'groups', 'out_filter', - 'hkernel', 'wkernel', 'hpad', 'wpad', 'hstride', 'wstride']) + ['in_dtype', 'out_dtype', 'height', 'width', 'in_filter', 'groups', + 'out_filter', 'hkernel', 'wkernel', 'hpad', 'wpad', 'hstride', 'wstride']) @tvm.target.generic_func def conv2d(input, filter, strides, padding, dilation, layout='NCHW', out_dtype=None): diff --git a/topi/python/topi/x86/conv2d.py b/topi/python/topi/x86/conv2d.py index 7775ca6f43b85..0e88eb8faeb31 100644 --- a/topi/python/topi/x86/conv2d.py +++ b/topi/python/topi/x86/conv2d.py @@ -37,7 +37,8 @@ logger = logging.getLogger('topi') -def _get_default_config(cfg, data, kernel, strides, padding, out_dtype, is_depthwise=False, layout='NCHW'): +def _get_default_config(cfg, data, kernel, strides, padding, out_dtype, is_depthwise=False, + layout='NCHW'): """ Get default schedule config for the workload """ @@ -280,10 +281,10 @@ def traverse(op): conv2d_avx_1x1._schedule_conv_nhwc_pack_int8(*args) else: raise ValueError("Only support 1x1 kernel with " - "schedule template.") + "schedule template.") else: raise ValueError("Not support this data type {} with " - "schedule template.".format(data.dtype)) + "schedule template.".format(data.dtype)) elif 'conv2d_nhwc' in op.tag: conv = op.output(0) @@ -453,9 +454,11 @@ def _declaration_conv_NCHWc(cfg, data, kernel, strides, n, ic_chunk, ih, iw, ic_bn = get_const_tuple(data.shape) in_channel = ic_chunk * ic_bn if data.dtype == 'uint8': - oc_chunk, ic_chunk_group, kernel_height, kernel_width, _, oc_bn, _ = get_const_tuple(kernel.shape) + oc_chunk, ic_chunk_group, kernel_height, kernel_width, _, oc_bn, _ = \ + get_const_tuple(kernel.shape) else: - oc_chunk, ic_chunk_group, kernel_height, kernel_width, _, oc_bn = get_const_tuple(kernel.shape) + oc_chunk, ic_chunk_group, kernel_height, kernel_width, _, oc_bn = \ + get_const_tuple(kernel.shape) num_filter = oc_chunk * oc_bn groups = ic_chunk // ic_chunk_group @@ -500,7 +503,7 @@ def _declaration_conv_NCHWc(cfg, data, kernel, strides, oc_block, ic_s_inner].astype(out_dtype), axis=[kh, kw, ic_outer, ic_f_inner, ic_s_inner]), name='conv2d_NCHWc_int8', tag="conv2d_NCHWc_int8") - elif data.dtype == 'uint8': + if data.dtype == 'uint8': # for int8 group conv support n_elems = 4 ic_chunk = 
in_channel//ic_bn @@ -509,7 +512,9 @@ def _declaration_conv_NCHWc(cfg, data, kernel, strides, ic_s_inner = tvm.reduce_axis((0, n_elems), name='ic_s_inner') oshape = (n, oc_chunk, out_height, out_width, oc_bn) return tvm.compute(oshape, lambda n, occ, oh, ow, oc_block: - tvm.sum(data_pad[n, (occ*oc_bn//(oc_chunk*oc_bn//groups))*(ic_chunk//groups)+ic_outer, oh*HSTR+kh, ow*WSTR+kw, + tvm.sum(data_pad[n, (occ*oc_bn//(oc_chunk*oc_bn//groups))*\ + (ic_chunk//groups)+ic_outer, + oh*HSTR+kh, ow*WSTR+kw, ic_f_inner * n_elems + ic_s_inner].astype(out_dtype) * kernel[occ, ic_outer, kh, kw, ic_f_inner, oc_block, ic_s_inner].astype(out_dtype), diff --git a/topi/python/topi/x86/conv2d_avx_1x1.py b/topi/python/topi/x86/conv2d_avx_1x1.py index 29bb0802e3c59..4994d4580ab58 100644 --- a/topi/python/topi/x86/conv2d_avx_1x1.py +++ b/topi/python/topi/x86/conv2d_avx_1x1.py @@ -269,7 +269,7 @@ def _declaration_conv_nhwc_pack(cfg, Input, Filter, stride, padding, dilation, o dilation_h, dilation_w = dilation batch, in_height, in_width, in_channel = Input.shape - kernel_h, kernel_w, num_filter, channel = Filter.shape + kernel_h, kernel_w, num_filter, channel = Filter.shape # compute the output shape dilated_kernel_h = (kernel_h - 1) * dilation_h + 1 @@ -287,7 +287,8 @@ def _declaration_conv_nhwc_pack(cfg, Input, Filter, stride, padding, dilation, o # packing the Filter to let memory access be consecutive for AVX512 intrinsic # Done in pre-compute stage packw_shape = (kernel_h, kernel_w, num_filter/16, 16*(channel/4), 4) - PackW = tvm.compute(packw_shape, lambda a, b, c, d, e: Filter[a][b][c*16+d%16][d/16*4+e], name="packed_filter") + PackW = tvm.compute(packw_shape, lambda a, b, c, d, e: Filter[a][b][c*16+d%16][d/16*4+e], + name="packed_filter") rc = tvm.reduce_axis((0, in_channel), name='rc') ry = tvm.reduce_axis((0, kernel_h), name='ry') @@ -352,4 +353,3 @@ def _schedule_conv_nhwc_pack_int8(s, cfg, data, conv_out, last): s[O].vectorize(oc_block) return s -
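The packed-weight indexing introduced by this series is easiest to see outside TVM. The sketch below is illustrative only: the array names are made up, and // is used where the patch writes /. It packs a (kh, kw, out_channel, in_channel) int8 filter exactly as _declaration_conv_nhwc_pack does and checks that the read pattern used in the compute recovers the original element, i.e. the packing is a pure re-indexing that keeps the 4 reduction values for each of the 16 output lanes contiguous for the AVX-512 dot-product intrinsic.

import numpy as np

kh, kw, oc, ic = 1, 1, 32, 64
filt = np.random.randint(-128, 128, size=(kh, kw, oc, ic)).astype("int8")

# packed[a, b, c, d, e] = filt[a, b, c*16 + d % 16, (d // 16) * 4 + e]
packed = np.empty((kh, kw, oc // 16, 16 * (ic // 4), 4), dtype="int8")
for a in range(kh):
    for b in range(kw):
        for c in range(oc // 16):
            for d in range(16 * (ic // 4)):
                for e in range(4):
                    packed[a, b, c, d, e] = filt[a, b, c * 16 + d % 16, (d // 16) * 4 + e]

# The conv compute reads packed[ry, rx, ff//16, (rc//4)*16 + ff%16, rc%4];
# that is exactly filt[ry, rx, ff, rc].
for ff in range(oc):
    for rc in range(ic):
        assert packed[0, 0, ff // 16, (rc // 4) * 16 + ff % 16, rc % 4] == filt[0, 0, ff, rc]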