diff --git a/topi/python/topi/nn/conv2d.py b/topi/python/topi/nn/conv2d.py
index 6b739ee030577..dc8312cb0d200 100644
--- a/topi/python/topi/nn/conv2d.py
+++ b/topi/python/topi/nn/conv2d.py
@@ -28,8 +28,8 @@
 
 # workload description of conv2d
 Workload = namedtuple('Workload',
-                      ['in_dtype', 'out_dtype', 'height', 'width', 'in_filter', 'groups', 'out_filter',
-                       'hkernel', 'wkernel', 'hpad', 'wpad', 'hstride', 'wstride'])
+                      ['in_dtype', 'out_dtype', 'height', 'width', 'in_filter', 'groups',
+                       'out_filter', 'hkernel', 'wkernel', 'hpad', 'wpad', 'hstride', 'wstride'])
 
 @tvm.target.generic_func
 def conv2d(input, filter, strides, padding, dilation, layout='NCHW', out_dtype=None):
diff --git a/topi/python/topi/x86/conv2d.py b/topi/python/topi/x86/conv2d.py
index 7775ca6f43b85..0e88eb8faeb31 100644
--- a/topi/python/topi/x86/conv2d.py
+++ b/topi/python/topi/x86/conv2d.py
@@ -37,7 +37,8 @@
 
 logger = logging.getLogger('topi')
 
-def _get_default_config(cfg, data, kernel, strides, padding, out_dtype, is_depthwise=False, layout='NCHW'):
+def _get_default_config(cfg, data, kernel, strides, padding, out_dtype, is_depthwise=False,
+                        layout='NCHW'):
     """
     Get default schedule config for the workload
     """
@@ -280,10 +281,10 @@ def traverse(op):
                     conv2d_avx_1x1._schedule_conv_nhwc_pack_int8(*args)
                 else:
                     raise ValueError("Only support 1x1 kernel with "
-                                 "schedule template.")
+                                     "schedule template.")
             else:
                 raise ValueError("Not support this data type {} with "
-                             "schedule template.".format(data.dtype))
+                                 "schedule template.".format(data.dtype))
 
         elif 'conv2d_nhwc' in op.tag:
             conv = op.output(0)
@@ -453,9 +454,11 @@ def _declaration_conv_NCHWc(cfg, data, kernel, strides,
 
     n, ic_chunk, ih, iw, ic_bn = get_const_tuple(data.shape)
     in_channel = ic_chunk * ic_bn
     if data.dtype == 'uint8':
-        oc_chunk, ic_chunk_group, kernel_height, kernel_width, _, oc_bn, _ = get_const_tuple(kernel.shape)
+        oc_chunk, ic_chunk_group, kernel_height, kernel_width, _, oc_bn, _ = \
+            get_const_tuple(kernel.shape)
     else:
-        oc_chunk, ic_chunk_group, kernel_height, kernel_width, _, oc_bn = get_const_tuple(kernel.shape)
+        oc_chunk, ic_chunk_group, kernel_height, kernel_width, _, oc_bn = \
+            get_const_tuple(kernel.shape)
     num_filter = oc_chunk * oc_bn
     groups = ic_chunk // ic_chunk_group
@@ -500,7 +503,7 @@ def _declaration_conv_NCHWc(cfg, data, kernel, strides,
                                           oc_block, ic_s_inner].astype(out_dtype),
                                    axis=[kh, kw, ic_outer, ic_f_inner, ic_s_inner]),
                            name='conv2d_NCHWc_int8', tag="conv2d_NCHWc_int8")
-    elif data.dtype == 'uint8':
+    if data.dtype == 'uint8':
         # for int8 group conv support
         n_elems = 4
         ic_chunk = in_channel//ic_bn
@@ -509,7 +512,9 @@ def _declaration_conv_NCHWc(cfg, data, kernel, strides,
         ic_s_inner = tvm.reduce_axis((0, n_elems), name='ic_s_inner')
         oshape = (n, oc_chunk, out_height, out_width, oc_bn)
         return tvm.compute(oshape, lambda n, occ, oh, ow, oc_block:
-                           tvm.sum(data_pad[n, (occ*oc_bn//(oc_chunk*oc_bn//groups))*(ic_chunk//groups)+ic_outer, oh*HSTR+kh, ow*WSTR+kw,
+                           tvm.sum(data_pad[n, (occ*oc_bn//(oc_chunk*oc_bn//groups))*\
+                                            (ic_chunk//groups)+ic_outer,
+                                            oh*HSTR+kh, ow*WSTR+kw,
                                    ic_f_inner * n_elems + ic_s_inner].astype(out_dtype) *
                                    kernel[occ, ic_outer, kh, kw, ic_f_inner, oc_block,
                                           ic_s_inner].astype(out_dtype),
diff --git a/topi/python/topi/x86/conv2d_avx_1x1.py b/topi/python/topi/x86/conv2d_avx_1x1.py
index 29bb0802e3c59..4994d4580ab58 100644
--- a/topi/python/topi/x86/conv2d_avx_1x1.py
+++ b/topi/python/topi/x86/conv2d_avx_1x1.py
@@ -269,7 +269,7 @@ def _declaration_conv_nhwc_pack(cfg, Input, Filter, stride, padding, dilation, o
     dilation_h, dilation_w = dilation
 
     batch, in_height, in_width, in_channel = Input.shape
-    kernel_h, kernel_w, num_filter, channel = Filter.shape 
+    kernel_h, kernel_w, num_filter, channel = Filter.shape
 
     # compute the output shape
     dilated_kernel_h = (kernel_h - 1) * dilation_h + 1
@@ -287,7 +287,8 @@ def _declaration_conv_nhwc_pack(cfg, Input, Filter, stride, padding, dilation, o
     # packing the Filter to let memory access be consecutive for AVX512 intrinsic
     # Done in pre-compute stage
     packw_shape = (kernel_h, kernel_w, num_filter/16, 16*(channel/4), 4)
-    PackW = tvm.compute(packw_shape, lambda a, b, c, d, e: Filter[a][b][c*16+d%16][d/16*4+e], name="packed_filter")
+    PackW = tvm.compute(packw_shape, lambda a, b, c, d, e: Filter[a][b][c*16+d%16][d/16*4+e],
+                        name="packed_filter")
 
     rc = tvm.reduce_axis((0, in_channel), name='rc')
     ry = tvm.reduce_axis((0, kernel_h), name='ry')
@@ -352,4 +353,3 @@ def _schedule_conv_nhwc_pack_int8(s, cfg, data, conv_out, last):
     s[O].vectorize(oc_block)
 
     return s
-
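
For reference, a minimal standalone sketch (plain Python; the helper name and the example sizes are made up, not part of the patch) of the group-offset arithmetic that the re-wrapped data_pad index above computes: for an output-channel chunk occ, (occ*oc_bn)//(oc_chunk*oc_bn//groups) selects the group, and multiplying by ic_chunk//groups gives the first input-channel chunk of that group, which ic_outer then walks.

    # Illustration only: mirrors the index arithmetic in the grouped
    # conv2d_NCHWc_int8 data_pad access; names/values are hypothetical.
    def group_ic_chunk_base(occ, oc_chunk, oc_bn, ic_chunk, groups):
        """First input-channel chunk read by output-channel chunk `occ`."""
        group = (occ * oc_bn) // (oc_chunk * oc_bn // groups)  # group of this output chunk
        return group * (ic_chunk // groups)                    # base offset added to ic_outer

    # Example: 8 output-channel chunks, 4 input-channel chunks, 2 groups.
    # Output chunks 0-3 (group 0) read input chunks starting at 0; chunks 4-7 start at 2.
    assert [group_ic_chunk_base(occ, 8, 16, 4, 2) for occ in range(8)] == [0, 0, 0, 0, 2, 2, 2, 2]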