
Commit: fix linter
lly-zero-one committed Apr 9, 2019
1 parent d5190c9 commit 87cbfbf
Showing 3 changed files with 17 additions and 12 deletions.
4 changes: 2 additions & 2 deletions topi/python/topi/nn/conv2d.py
@@ -28,8 +28,8 @@
 
 # workload description of conv2d
 Workload = namedtuple('Workload',
-                      ['in_dtype', 'out_dtype', 'height', 'width', 'in_filter', 'groups', 'out_filter',
-                       'hkernel', 'wkernel', 'hpad', 'wpad', 'hstride', 'wstride'])
+                      ['in_dtype', 'out_dtype', 'height', 'width', 'in_filter', 'groups',
+                       'out_filter', 'hkernel', 'wkernel', 'hpad', 'wpad', 'hstride', 'wstride'])
 
 @tvm.target.generic_func
 def conv2d(input, filter, strides, padding, dilation, layout='NCHW', out_dtype=None):
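For readers skimming the hunk above: Workload is a plain namedtuple describing a conv2d workload, and the change only re-wraps its field list to satisfy the line-length lint. A minimal sketch of how such a tuple is built, with illustrative values that are not part of this commit:

```python
from collections import namedtuple

# Field order mirrors the new wrapping in topi/python/topi/nn/conv2d.py.
Workload = namedtuple('Workload',
                      ['in_dtype', 'out_dtype', 'height', 'width', 'in_filter', 'groups',
                       'out_filter', 'hkernel', 'wkernel', 'hpad', 'wpad', 'hstride', 'wstride'])

# Hypothetical 3x3, stride-1, pad-1 convolution on a 224x224 float32 input.
wkl = Workload('float32', 'float32', 224, 224, 3, 1, 64, 3, 3, 1, 1, 1, 1)
print(wkl.out_filter)  # 64
```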
19 changes: 12 additions & 7 deletions topi/python/topi/x86/conv2d.py
@@ -37,7 +37,8 @@
 
 logger = logging.getLogger('topi')
 
-def _get_default_config(cfg, data, kernel, strides, padding, out_dtype, is_depthwise=False, layout='NCHW'):
+def _get_default_config(cfg, data, kernel, strides, padding, out_dtype, is_depthwise=False,
+                        layout='NCHW'):
     """
     Get default schedule config for the workload
     """
@@ -280,10 +281,10 @@ def traverse(op):
                     conv2d_avx_1x1._schedule_conv_nhwc_pack_int8(*args)
                 else:
                     raise ValueError("Only support 1x1 kernel with "
-                        "schedule template.")
+                                     "schedule template.")
             else:
                 raise ValueError("Not support this data type {} with "
-                    "schedule template.".format(data.dtype))
+                                 "schedule template.".format(data.dtype))
 
         elif 'conv2d_nhwc' in op.tag:
             conv = op.output(0)
@@ -453,9 +454,11 @@ def _declaration_conv_NCHWc(cfg, data, kernel, strides,
     n, ic_chunk, ih, iw, ic_bn = get_const_tuple(data.shape)
     in_channel = ic_chunk * ic_bn
     if data.dtype == 'uint8':
-        oc_chunk, ic_chunk_group, kernel_height, kernel_width, _, oc_bn, _ = get_const_tuple(kernel.shape)
+        oc_chunk, ic_chunk_group, kernel_height, kernel_width, _, oc_bn, _ = \
+            get_const_tuple(kernel.shape)
     else:
-        oc_chunk, ic_chunk_group, kernel_height, kernel_width, _, oc_bn = get_const_tuple(kernel.shape)
+        oc_chunk, ic_chunk_group, kernel_height, kernel_width, _, oc_bn = \
+            get_const_tuple(kernel.shape)
     num_filter = oc_chunk * oc_bn
     groups = ic_chunk // ic_chunk_group
 
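The hunk above wraps two over-long unpacking assignments with a backslash continuation. As a side note, PEP 8 generally prefers implicit continuation inside parentheses; a minimal sketch of both styles on a made-up shape, where `get_const_tuple` is a stand-in for topi's helper and all names are illustrative only:

```python
def get_const_tuple(shape):
    # Stand-in for topi's helper, which converts TVM shape exprs to a Python tuple.
    return tuple(int(dim) for dim in shape)

shape = (8, 2, 3, 3, 4, 16, 4)

# Style used in this commit: explicit backslash continuation.
oc_chunk, ic_chunk_group, kernel_height, kernel_width, _, oc_bn, _ = \
    get_const_tuple(shape)

# Equivalent parenthesized form, which PEP 8 tends to prefer.
(oc_chunk, ic_chunk_group, kernel_height,
 kernel_width, _, oc_bn, _) = get_const_tuple(shape)
```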
@@ -500,7 +503,7 @@ def _declaration_conv_NCHWc(cfg, data, kernel, strides,
                                       oc_block, ic_s_inner].astype(out_dtype),
                                axis=[kh, kw, ic_outer, ic_f_inner, ic_s_inner]),
                        name='conv2d_NCHWc_int8', tag="conv2d_NCHWc_int8")
-    elif data.dtype == 'uint8':
+    if data.dtype == 'uint8':
         # for int8 group conv support
         n_elems = 4
         ic_chunk = in_channel//ic_bn
@@ -509,7 +512,9 @@ def _declaration_conv_NCHWc(cfg, data, kernel, strides,
     ic_s_inner = tvm.reduce_axis((0, n_elems), name='ic_s_inner')
     oshape = (n, oc_chunk, out_height, out_width, oc_bn)
     return tvm.compute(oshape, lambda n, occ, oh, ow, oc_block:
-                       tvm.sum(data_pad[n, (occ*oc_bn//(oc_chunk*oc_bn//groups))*(ic_chunk//groups)+ic_outer, oh*HSTR+kh, ow*WSTR+kw,
+                       tvm.sum(data_pad[n, (occ*oc_bn//(oc_chunk*oc_bn//groups))*\
+                                        (ic_chunk//groups)+ic_outer,
+                                        oh*HSTR+kh, ow*WSTR+kw,
                                         ic_f_inner * n_elems + ic_s_inner].astype(out_dtype) *
                                kernel[occ, ic_outer, kh, kw, ic_f_inner,
                                       oc_block, ic_s_inner].astype(out_dtype),
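One subtlety in the last hunk: the wrapped index expression maps an output-channel block `occ` to the first input-channel chunk of its group. A plain-Python sketch of that mapping under hypothetical sizes (none of these numbers come from the commit):

```python
# Hypothetical sizes: 4 groups, 8 output-channel chunks of 16, 8 input-channel chunks.
groups, oc_chunk, oc_bn, ic_chunk = 4, 8, 16, 8

def group_ic_base(occ):
    # Which group owns output-channel chunk `occ`, then that group's first ic chunk.
    group = occ * oc_bn // (oc_chunk * oc_bn // groups)
    return group * (ic_chunk // groups)

print([group_ic_base(occ) for occ in range(oc_chunk)])
# [0, 0, 2, 2, 4, 4, 6, 6] -- each pair of oc chunks reads its own slice of ic chunks.
```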
6 changes: 3 additions & 3 deletions topi/python/topi/x86/conv2d_avx_1x1.py
@@ -269,7 +269,7 @@ def _declaration_conv_nhwc_pack(cfg, Input, Filter, stride, padding, dilation, o
     dilation_h, dilation_w = dilation
 
     batch, in_height, in_width, in_channel = Input.shape
-    kernel_h, kernel_w, num_filter, channel = Filter.shape 
+    kernel_h, kernel_w, num_filter, channel = Filter.shape
 
     # compute the output shape
     dilated_kernel_h = (kernel_h - 1) * dilation_h + 1
@@ -287,7 +287,8 @@ def _declaration_conv_nhwc_pack(cfg, Input, Filter, stride, padding, dilation, o
     # packing the Filter to let memory access be consecutive for AVX512 intrinsic
     # Done in pre-compute stage
     packw_shape = (kernel_h, kernel_w, num_filter/16, 16*(channel/4), 4)
-    PackW = tvm.compute(packw_shape, lambda a, b, c, d, e: Filter[a][b][c*16+d%16][d/16*4+e], name="packed_filter")
+    PackW = tvm.compute(packw_shape, lambda a, b, c, d, e: Filter[a][b][c*16+d%16][d/16*4+e],
+                        name="packed_filter")
 
     rc = tvm.reduce_axis((0, in_channel), name='rc')
     ry = tvm.reduce_axis((0, kernel_h), name='ry')
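The PackW compute above is a pure index permutation; per the comment in the hunk, it exists so the AVX512 intrinsic sees consecutive memory. A numpy sketch of the same mapping, with hypothetical shapes and `//` substituted for the Python-2-style `/` division (not part of the commit):

```python
import numpy as np

# Hypothetical HWOI filter: 1x1 kernel, 32 filters, 8 input channels.
kernel_h, kernel_w, num_filter, channel = 1, 1, 32, 8
filt = np.arange(kernel_h * kernel_w * num_filter * channel,
                 dtype=np.int32).reshape(kernel_h, kernel_w, num_filter, channel)

packw = np.empty((kernel_h, kernel_w, num_filter // 16, 16 * (channel // 4), 4),
                 dtype=filt.dtype)
for a in range(kernel_h):
    for b in range(kernel_w):
        for c in range(num_filter // 16):
            for d in range(16 * (channel // 4)):
                for e in range(4):
                    # Same indexing as the PackW lambda in the hunk above.
                    packw[a, b, c, d, e] = filt[a, b, c*16 + d % 16, d//16*4 + e]
```

The innermost axis holds groups of 4 input channels per filter lane, which presumably lines up with the 16-lane, 4-way int8 dot-product pattern of AVX-512.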
@@ -352,4 +353,3 @@ def _schedule_conv_nhwc_pack_int8(s, cfg, data, conv_out, last):
     s[O].vectorize(oc_block)
 
     return s
-