From 27d72dce2353cd84700d9bd8d260f3aeaf22261c Mon Sep 17 00:00:00 2001 From: Gavin Uberti Date: Mon, 8 Aug 2022 02:08:58 -0700 Subject: [PATCH 1/6] Add depthwise conv2d schedule for Cortex-M DSP depthwise_conv2d kernel re-arranging fast bytecode for dsp copy/modify helper code Bugfixes from code testing Much of the depthwise conv2d schedule V1 DSP DWC2D black formatting Minor work to address comments --- python/tvm/relay/op/strategy/arm_cpu.py | 23 ++ python/tvm/topi/arm_cpu/depthwise_conv2d.py | 19 + .../arm_cpu/mprofile/dsp/depthwise_conv2d.py | 378 ++++++++++++++++++ 3 files changed, 420 insertions(+) create mode 100644 python/tvm/topi/arm_cpu/mprofile/dsp/depthwise_conv2d.py diff --git a/python/tvm/relay/op/strategy/arm_cpu.py b/python/tvm/relay/op/strategy/arm_cpu.py index ba28b6c7c31c..15cc76885f53 100644 --- a/python/tvm/relay/op/strategy/arm_cpu.py +++ b/python/tvm/relay/op/strategy/arm_cpu.py @@ -235,6 +235,29 @@ def conv2d_strategy_arm_cpu(attrs, inputs, out_type, target): wrap_topi_schedule(topi.arm_cpu.schedule_depthwise_conv2d_nhwc), name="depthwise_conv2d_nhwc.arm_cpu", ) + + # Optimized special case depthwiseConv2D operation. Requires a 3x3 kernel, a + # NHWC layout, a HWOI kernel layout (which we would ideally rearrange), no dilation, + # "SAME" padding, int8 inputs and outputs, the same number of input and output + # channels, and for that channel count to be divisible by 4. + # + # Additional work could remove some of these restrictions. + + elif ( + isa.has_dsp_support + and kernel.shape[0] == kernel.shape[1] == 3 + and dilation_w == dilation_h == 1 + and kernel.shape[3] == 1 # channel_multiplier == 1 + and data.dtype == "int8" + and padding == "SAME" + and data.shape[3] % 4 == 0 + ): + strategy.add_implementation( + wrap_compute_conv2d(topi.arm_cpu.depthwise_conv2d_nhwc), + wrap_topi_schedule(topi.arm_cpu.schedule_depthwise_conv2d_nhwc), + name="depthwise_conv2d_nhwc_dsp.arm_cpu", + ) + else: logger.warning("depthwise_conv2d with layout NHWC is not optimized for arm cpu.") strategy.add_implementation( diff --git a/python/tvm/topi/arm_cpu/depthwise_conv2d.py b/python/tvm/topi/arm_cpu/depthwise_conv2d.py index c21480724ae4..333db3d5e014 100644 --- a/python/tvm/topi/arm_cpu/depthwise_conv2d.py +++ b/python/tvm/topi/arm_cpu/depthwise_conv2d.py @@ -28,6 +28,11 @@ from .tensor_intrin import smlal_int16_int32 from .arm_utils import is_aarch64_arm +from .mprofile.dsp.depthwise_conv2d import ( + depthwise_conv2d_nhwc_dsp_compute, + depthwise_conv2d_nhwc_dsp_schedule, +) + @autotvm.register_topi_compute("depthwise_conv2d_nchw.arm_cpu") def depthwise_conv2d_nchw(_, data, kernel, strides, padding, dilation, out_dtype): @@ -699,3 +704,17 @@ def _schedule_spatial_pack(cfg, s, data_vec, kernel_vec, conv, output, last): s[kernel_vec].parallel(co) return s + + +@autotvm.register_topi_compute("depthwise_conv2d_nhwc_dsp.arm_cpu") +def depthwise_conv2d_nhwc_dsp(cfg, data, kernel, strides, padding, dilation, out_dtype): + """Compute conv2d_nhwc with v7e-m DSP instructions.""" + return depthwise_conv2d_nhwc_dsp_compute( + cfg, data, kernel, strides, padding, dilation, out_dtype + ) + + +@autotvm.register_topi_schedule("depthwise_conv2d_nhwc_dsp.arm_cpu") +def schedule_depthwise_conv2d_nhwc_dsp(cfg, outs): + """Create schedule for conv2d_nhwc_dsp""" + return depthwise_conv2d_nhwc_dsp_schedule(cfg, outs) diff --git a/python/tvm/topi/arm_cpu/mprofile/dsp/depthwise_conv2d.py b/python/tvm/topi/arm_cpu/mprofile/dsp/depthwise_conv2d.py new file mode 100644 index 000000000000..c023993e91d1 --- 
/dev/null +++ b/python/tvm/topi/arm_cpu/mprofile/dsp/depthwise_conv2d.py @@ -0,0 +1,378 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Direct implementation of conv2d."""
+
+from tvm import autotvm
+from tvm.autotvm.task import deserialize_args
+from tvm import te
+from tvm.topi.utils import simplify, traverse_inline
+from tvm.topi.nn.pad import pad
+from tvm.topi.nn.utils import get_pad_tuple
+from tvm.tir.expr import Mul
+
+# For depthwise_conv2d, kernels are normally given in HWOI format,
+# which when input_channels = output channels, we will call HWC.
+# This is bad, as we want "related" parts of the kernel to be next
+# to each other, so we can use __SMLAD later.
+#
+# Consider a 3x3 int8 kernel with no bias vector, with eight
+# channels. Let us specify entries in the kernel as H_W_C - i.e.
+# where 0_2_3 represents the rightmost position in the first row
+# of channel 4/8 (4 because of zero indexing). Each [ ] represents
+# a 32-bit integer. We currently store the kernel as:
+#
+# 0 ................................31
+# [ 0_0_0 || 0_0_1 || 0_0_2 || 0_0_3 ] [ 0_0_4 || 0_0_5 || 0_0_6 || 0_0_7 ]
+# [ 0_1_0 || 0_1_1 || 0_1_2 || 0_1_3 ] [ 0_1_4 || 0_1_5 || 0_1_6 || 0_1_7 ]
+# [ 0_2_0 || 0_2_1 || 0_2_2 || 0_2_3 ] [ 0_2_4 || 0_2_5 || 0_2_6 || 0_2_7 ]
+# [ 1_0_0 || 1_0_1 || 1_0_2 || 1_0_3 ] [ 1_0_4 || 1_0_5 || 1_0_6 || 1_0_7 ]
+# [ 1_1_0 || 1_1_1 || 1_1_2 || 1_1_3 ] [ 1_1_4 || 1_1_5 || 1_1_6 || 1_1_7 ]
+# [ 1_2_0 || 1_2_1 || 1_2_2 || 1_2_3 ] [ 1_2_4 || 1_2_5 || 1_2_6 || 1_2_7 ]
+# [ 2_0_0 || 2_0_1 || 2_0_2 || 2_0_3 ] [ 2_0_4 || 2_0_5 || 2_0_6 || 2_0_7 ]
+# [ 2_1_0 || 2_1_1 || 2_1_2 || 2_1_3 ] [ 2_1_4 || 2_1_5 || 2_1_6 || 2_1_7 ]
+# [ 2_2_0 || 2_2_1 || 2_2_2 || 2_2_3 ] [ 2_2_4 || 2_2_5 || 2_2_6 || 2_2_7 ]
+#
+# Let 0x00 be all zeros. We rearrange into:
+#
+# 0 ................................31
+# [ 0_0_0 || 0_0_1 || 0_1_0 || 0_1_1 ] [ 0_0_2 || 0_0_3 || 0_1_2 || 0_1_3 ]
+# [ 0_2_0 || 0_2_1 || 1_0_0 || 1_0_1 ] [ 0_2_2 || 0_2_3 || 1_0_2 || 1_0_3 ]
+# [ 1_1_0 || 1_1_1 || 1_2_0 || 1_2_1 ] [ 1_1_2 || 1_1_3 || 1_2_2 || 1_2_3 ]
+# [ 2_0_0 || 2_0_1 || 2_1_0 || 2_1_1 ] [ 2_0_2 || 2_0_3 || 2_1_2 || 2_1_3 ]
+# [ 2_2_0 || 2_2_1 || 0x000 || 0x000 ] [ 2_2_2 || 2_2_3 || 0x000 || 0x000 ]
+# [ 0_0_4 || 0_0_5 || 0_1_4 || 0_1_5 ] [ 0_0_6 || 0_0_7 || 0_1_6 || 0_1_7 ]
+# [ 0_2_4 || 0_2_5 || 1_0_4 || 1_0_5 ] [ 0_2_6 || 0_2_7 || 1_0_6 || 1_0_7 ]
+# [ 1_1_4 || 1_1_5 || 1_2_4 || 1_2_5 ] [ 1_1_6 || 1_1_7 || 1_2_6 || 1_2_7 ]
+# [ 2_0_4 || 2_0_5 || 2_1_4 || 2_1_5 ] [ 2_0_6 || 2_0_7 || 2_1_6 || 2_1_7 ]
+# [ 2_2_4 || 2_2_5 || 0x000 || 0x000 ] [ 2_2_6 || 2_2_7 || 0x000 || 0x000 ]
+#
+# This saves us six operations compared to the original ordering, as we
+# do not need halfword packing instructions.
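As a concrete check of the diagram above, here is a small pure-Python model of the index arithmetic (illustrative only; it mirrors the formulas used by _rearrange_kernel below and is not part of the patch):

    K_H = K_W = 3  # kernel height/width
    C = 8          # channels

    def packed_entry(c_o, pos, c_i):
        """Which original kernel entry lands at (c_o, pos, c_i) after packing."""
        channel = 2 * (pos % 2) + (c_i % 2) + 4 * c_o
        true_pos = 2 * (pos // 2) + (c_i // 2)
        if true_pos >= K_H * K_W:
            return "0x00"  # zero fill for the odd tap count
        return f"{true_pos // K_W}_{true_pos % K_W}_{channel}"

    # First packed word of channel group 0, matching row one of the diagram:
    print([packed_entry(0, 0, c_i) for c_i in range(4)])
    # -> ['0_0_0', '0_0_1', '0_1_0', '0_1_1']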
+# +# This kernel re-arranging function will be used for 3x3 kernels (as that +# is all this DSP implementation currently supports) but would work with +# any M*N kernel such that M*N is odd. + + +def _rearrange_kernel(kernel): + # Kernel must be HWC format. + K_H, K_W, C, _ = get_const_tuple(kernel.shape) + assert C % 4 == 0 + + # TODO remove this restriction + assert (K_W * K_H) % 2 == 1 + + def fcompute(c_o, pos, c_i): + channel = (2 * (pos % 2)) + (c_i % 2) + (4 * c_o) + true_pos_index = 2 * (pos // 2) + (c_i // 2) + + return tir.if_then_else( + true_pos_index < (K_H * K_W), + kernel[true_pos_index // K_W, true_pos_index % K_W, channel, 0], + tir.const(0, "int8"), + ) + + return te.compute( + (C // 4, K_H * K_W + 1, 4), lambda co, pos, ci: fcompute(co, pos, ci), name="packed_kernel" + ) + + +def depthwise_conv2d_nhwc_dsp(*args, **kwargs): + """Defines the v7e-m DSP instructions of depthwise_conv2d.""" + assert not kwargs, "Do not support kwargs in template function call" + args = deserialize_args(args) + data, kernel = args[:2] + layout = args[-2] + cfg = autotvm.get_config() + args = [cfg] + args + assert layout == "NHWC" + conv = depthwise_conv2d_nhwc_dsp_compute(*args) + sched = depthwise_conv2d_nhwc_dsp_schedule(cfg, [data, kernel, conv]) + return sched, [data, kernel, conv] + + +depthwise_conv2d_nhwc_dsp.template_key = "dsp" +depthwise_conv2d_nhwc_dsp.default_data_layout = "NHWC" +depthwise_conv2d_nhwc_dsp.default_kernel_layout = "HWOI" + + +def depthwise_conv2d_nhwc_dsp_compute(cfg, data, kernel, strides, padding, dilation, out_dtype): + """Compute function for v7e-m DSP instructions of DepthwiseConv2D. Has a lot of requirements + for use - not not all apply, the fallback implementation will be used instead.""" + assert isinstance(strides, int) or len(strides) == 2 + assert isinstance(dilation, int) or len(dilation) == 2 + + if isinstance(strides, int): + STRIDE_H = STRIDE_W = strides + else: + STRIDE_H, STRIDE_W = strides + + # We do not support dilation currently. It would be possible, but it would require + # modifying the way the kernel is packed. Gnarly. + if isinstance(dilation, int): + DILATION_H = DILATION_W = dilation + else: + DILATION_H, DILATION_H = dilation + assert DILATION_H == DILATION_H == 1 + + B, H, W, C = data.shape + K_H, K_W, _, _ = kernel.shape + + # We require that the number of channels be divisible by 4. This restriction could + # be removed with strip mining if people cared. + assert C % 4 == 0 + + # We don't support different numbers of input and output channels. + assert C == kernel.shape[2] + assert kernel.shape[3] == 1 + + # The int16 case could also be optimized here, but would require writing a whole new + # micro kernel. Probably not worth it. + assert out_dtype == "int8" + + # This can pretty easily be generalized in the future. Likely worth doing, and this + # function was written to make doing so easy. Should only require adding more calls + # to QUAD_CHANNEL_REARRANGE_SUM. + assert K_W == K_H == 3 + + # We do not currently support custom padding. Would be pretty easy to implement. + assert padding == "SAME" or padding == "VALID" + + # Padding the data requires COPYING THE ENTIRE INPUT TENSOR, which + # is slow and bad. We should really implement a strip mining + # routine to avoid this, but TVM has terrible support for that. + + if padding == "SAME": + # This assumption makes the logic easier. Could be removed with work. 
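# (Worked example: with a 48x48 input and a 3x3 kernel, stride 1 pads one
# pixel on all four sides and the output stays 48x48, while stride 2 pads
# only the bottom and right edges by one pixel and the output is 24x24.)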
+ assert H % STRIDE_H == W % STRIDE_W == 0 + + OUT_H = H // STRIDE_H + OUT_W = W // STRIDE_W + + # Padding this way is weird and quirky, and we only do it to match TFLite. + pad_top = 1 if STRIDE_H == 1 else 0 + pad_left = 1 if STRIDE_W == 1 else 0 + + data_padded = pad( + data, [0, pad_top, pad_left, 0], [0, K_H // 2, K_W // 2, 0], name="padded_data" + ) + + elif PADDING_STRATEGY == "VALID": + assert H > K_H and W > K_W + OUT_H = (H - K_H) // STRIDE_H + 1 + OUT_W = (W - K_W) // STRIDE_W + 1 + data_padded = data + + else: + raise RuntimeError() + _, P_H, P_W, _ = data_padded.shape + + packed_kernel = _rearrange_kernel(kernel) + kh_i = te.reduce_axis((0, K_H), name="kh_i") + kw_i = te.reduce_axis((0, K_W), name="kw_i") + return te.compute( + (B, OUT_H, OUT_W, C), + lambda h, i, j, k: te.sum( + DATA_PAD[h, (i * STRIDE_H) + kh_i, (j * STRIDE_W) + kw_i, k] + * PACKED_KER[ + k // 4, + (2 * ((3 * kh_i + kw_i) // 2)) + ((k % 4) // 2), + (2 * ((kh_i + kw_i) % 2)) + (k % 2), + ], + axis=(kh_i, kw_i), + ), + name="depthwise_conv2d", + tag=f"depthwise_conv2d_nhwc_{P_H}_{P_W}_dsp", + ) + + +def depthwise_conv2d_nhwc_dsp_schedule(cfg, outs): + + """Schedule function for v7e-m DSP instructions of conv2d.""" + schedule = te.create_schedule([x.op for x in outs]) + + def _callback(op): + if "depthwise_conv2d_nhwc" not in op.tag: + return + + # extract tensors + output = op.output(0) + padded_data = conv_out.op.input_tensors[0] + packed_kernel = conv_out.op.input_tensors[1] + kernel = packed_kernel.op.input_tensors[0] + + B, P_H, P_W, C = padded_data.shape + K_H, K_W, _, _ = kernel.shape + suffix = "".join(random.choices(string.ascii_uppercase, k=8)) + + b_ax, y_ax, x_ax, c_ax = schedule[output].op.axis + ky_ax, kx_ax = schedule[output].op.reduce_axis + c_ax_o, c_ax_i = schedule[output].split(c_ax, factor=4) + schedule[output].reorder(b_ax, c_ax_o, y_ax, x_ax, ky_ax, kx_ax, c_ax_i) + + quad_channel_convolve = intrin_quad_channel_convolve_3x3(P_H, P_W, C, K_H, K_W, suffix) + s[CONVOLVED].tensorize(ky_ax, gemv) + sched[output].pragma( + b_ax, "import_c", quad_channel_convolve_3x3_impl(P_H, P_W, C, K_H, K_W, suffix) + ) + + traverse_inline(sched, outs[-1].op, _callback) + return sched + + +def intrin_quad_channel_convolve_3x3(P_H, P_W, C, K_H, K_W, suffix): + a = te.placeholder((K_H, K_W, 4), name="a", dtype="int8") + b = te.placeholder((K_H * K_W + 1, 4), name="b", dtype="int8") + kh_i = te.reduce_axis((0, K_H), name="kh_i") + kw_i = te.reduce_axis((0, K_W), name="kw_i") + + c = te.compute( + (4,), + lambda k: te.sum( + a[kh_i, kw_i, k] + * b[ + (2 * ((3 * kh_i + kw_i) // 2)) + ((k % 4) // 2), + (2 * ((kh_i + kw_i) % 2)) + (k % 2), + ], + axis=(kh_i, kw_i), + ), + name="c", + ) + + Ab = tvm.tir.decl_buffer(a.shape, a.dtype, name="A", offset_factor=1, strides=[P_W * C, C, 1]) + Bb = tvm.tir.decl_buffer(b.shape, b.dtype, name="B", offset_factor=1, strides=[4, 1]) + Cb = tvm.tir.decl_buffer(c.shape, c.dtype, name="C", offset_factor=1, strides=[1]) + + def intrin_func(ins, outs): + ib = tvm.tir.ir_builder.create() + aa, bb = ins + cc = outs[0] + ib.emit( + tvm.tir.call_extern( + "int32", + f"kernel_convolve_noedge_{P_H}_{P_W}_{C}_{K_H}_{K_W}_{suffix}", + cc.access_ptr("w"), + aa.access_ptr("r"), + bb.access_ptr("r"), + ) + ) + return ib.get() + + return te.decl_tensor_intrin(c.op, intrin_func, binds={a: Ab, b: Bb, c: Cb}) + + +def quad_channel_convolve_3x3_impl(P_H, P_W, C, K_H, K_W, suffix): + return ( + textwrap.dedent( + f""" + + #include + + // __SXTB16(_ROR(X, Y)) is combined into one assembly instruction + 
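// Background, for illustration: __SXTB16(x) sign-extends bytes 0 and 2 of
// x into two 16-bit lanes, and __SXTB16(__ROR(x, 8)) does the same for
// bytes 1 and 3. Two instructions thus unpack four packed int8 channels
// into the sign-extended int16 pairs consumed by the dual 16x16
// multiply-accumulate __SMLAD below.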
+ #define QUAD_CHANNEL_TENSOR_REARRANGE_SUM_DSP( \ + arranged_kernel, \ + tensor_v0_c3210, tensor_v1_c3210, \ + sum0, sum1, sum2, sum3) {{ \ + \ + uint32_t tensor_v0_c20 = __SXTB16(tensor_v0_c3210); \ + uint32_t tensor_v0_c31 = __SXTB16(__ROR(tensor_v0_c3210, 8)); \ + uint32_t tensor_v1_c20 = __SXTB16(tensor_v1_c3210); \ + uint32_t tensor_v1_c31 = __SXTB16(__ROR(tensor_v1_c3210, 8)); \ + \ + uint32_t kernel_v1c1_v1c0_v0c1_v0c0 = *arranged_kernel++; \ + uint32_t kernel_v1c3_v1c2_v0c3_v0c2 = *arranged_kernel++; \ + \ + uint32_t kernel_v10_c0 = __SXTB16(kernel_v1c1_v1c0_v0c1_v0c0); \ + uint32_t kernel_v10_c1 = __SXTB16(__ROR(kernel_v1c1_v1c0_v0c1_v0c0, 8)); \ + uint32_t kernel_v10_c2 = __SXTB16(kernel_v1c3_v1c2_v0c3_v0c2); \ + uint32_t kernel_v10_c3 = __SXTB16(__ROR(kernel_v1c3_v1c2_v0c3_v0c2, 8)); \ + \ + uint32_t tensor_v10_c0 = __PKHBT(tensor_v0_c20, tensor_v1_c20, 16); \ + uint32_t tensor_v10_c1 = __PKHBT(tensor_v0_c31, tensor_v1_c31, 16); \ + uint32_t tensor_v10_c2 = __PKHTB(tensor_v1_c20, tensor_v0_c20, 16); \ + uint32_t tensor_v10_c3 = __PKHTB(tensor_v1_c31, tensor_v0_c31, 16); \ + \ + sum_c0 = __SMLAD(tensor_v10_c0, kernel_v10_c0, sum_c0); \ + sum_c1 = __SMLAD(tensor_v10_c1, kernel_v10_c1, sum_c1); \ + sum_c2 = __SMLAD(tensor_v10_c2, kernel_v10_c2, sum_c2); \ + sum_c3 = __SMLAD(tensor_v10_c3, kernel_v10_c3, sum_c3); \ + }} + + + /* Here, we want to take the LOWER BYTES of 32 bit numbers v3 - v0 + * and concatenate them as "v3 || v2 || v1 || v0". In C++, this is: + + return ((sum_c0 & 0x000000FF)) + + ((sum_c1 & 0x000000FF) << 8) + + ((sum_c2 & 0x000000FF) << 16) + + ((sum_c3 & 0x000000FF) << 24); + + * Naively, this takes 4x ands, 3x adds, 3x shifts. With optimization flags, + * clang optimizes this down to eight operations: + + mov r12, #255 + and r0, r0, #255 + orr r12, r12, #65280 + and r1, r12, r1, lsl #8 + orr r0, r1, r0 + and r1, r2, #255 + orr r0, r0, r1, lsl #16 + orr r0, r0, r3, lsl #24 + + * But being clever engineers, we can do it in four instructions. I think, + * but have been unable to prove, that fewer is impossible. */ + + #define WRITE_QUAD_BYTE_JOIN_DSP(out, v3, v2, v1, v0) {{ \ + uint32_t v3_00_v1_00 = PKHBT(v1 << 8, v3, 24); \ + uint32_t gg_v2_gg_v0 = PKHBT(v0, v2, 16); \ + out[0] = UXTAB16(v3_00_v1_00, gg_v2_gg_v0); \ + }} + + /* We do four channels at once to get this speed boost. 
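* Note: a 3x3 kernel has nine taps, and the packed kernel pads these to
* ten entries so they can be consumed two positions at a time. That is why
* the body below invokes the rearrange-sum macro five times; the last call
* pairs the final tap with the zero padding (its second tensor argument is
* a literal 0).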
*/ + extern "C" int kernel_convolve_noedge_{P_H}_{P_W}_{C}_{K_H}_{K_W}_{suffix}( + uint32_t *out, + uint32_t *tensor, + uint32_t *packed_kernel) {{ + + uint32_t sum_c0 = 0; + uint32_t sum_c1 = 0; + uint32_t sum_c2 = 0; + uint32_t sum_c3 = 0; + + QUAD_CHANNEL_TENSOR_REARRANGE_SUM( + packed_kernel, *tensor, *(tensor + {C // 4}), + sum_c0, sum_c1, sum_c2, sum_c3) + QUAD_CHANNEL_TENSOR_REARRANGE_SUM( + packed_kernel, *(tensor + {(2) * C // 4}), *(tensor + {P_W * (C // 4)}), + sum_c0, sum_c1, sum_c2, sum_c3) + QUAD_CHANNEL_TENSOR_REARRANGE_SUM( + packed_kernel, *(tensor + {(P_W + 1) * (C // 4)}), *(tensor + {(P_W + 2) * (C // 4)}), + sum_c0, sum_c1, sum_c2, sum_c3) + QUAD_CHANNEL_TENSOR_REARRANGE_SUM( + packed_kernel, *(tensor + {(2 * P_W) * (C // 4)}), *(tensor + {(2 * P_W + 1) * (C // 4)}), + sum_c0, sum_c1, sum_c2, sum_c3) + QUAD_CHANNEL_TENSOR_REARRANGE_SUM( + packed_kernel, *(tensor + {(2 * P_W + 2) * (C // 4)}), 0, + sum_c0, sum_c1, sum_c2, sum_c3) + + WRITE_QUAD_BYTE_JOIN_DSP(out, sum_c3, sum_c2, sum_c1, sum_c0); + }} + """ + ), + ) From 2dcb889d296f6af34f9acc766ea5d297fcbf4d2c Mon Sep 17 00:00:00 2001 From: Gavin Uberti Date: Tue, 30 Aug 2022 06:58:40 -0700 Subject: [PATCH 2/6] Code reorganization and testing Functional DWC2D schedule with test Code cleanup and linting Fix padding to match Relay and add tests Fix test cases --- python/tvm/relay/op/strategy/arm_cpu.py | 7 +- .../arm_cpu/mprofile/dsp/depthwise_conv2d.py | 275 +++++------------- .../dsp/micro_kernel/quad_channel_convolve.py | 178 ++++++++++++ .../strategy/arm_cpu/test_depthwise_conv2d.py | 24 ++ 4 files changed, 280 insertions(+), 204 deletions(-) create mode 100644 python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/quad_channel_convolve.py diff --git a/python/tvm/relay/op/strategy/arm_cpu.py b/python/tvm/relay/op/strategy/arm_cpu.py index 15cc76885f53..82f3a2410e0f 100644 --- a/python/tvm/relay/op/strategy/arm_cpu.py +++ b/python/tvm/relay/op/strategy/arm_cpu.py @@ -244,17 +244,16 @@ def conv2d_strategy_arm_cpu(attrs, inputs, out_type, target): # Additional work could remove some of these restrictions. elif ( - isa.has_dsp_support + target.features.has_dsp and kernel.shape[0] == kernel.shape[1] == 3 and dilation_w == dilation_h == 1 and kernel.shape[3] == 1 # channel_multiplier == 1 and data.dtype == "int8" - and padding == "SAME" and data.shape[3] % 4 == 0 ): strategy.add_implementation( - wrap_compute_conv2d(topi.arm_cpu.depthwise_conv2d_nhwc), - wrap_topi_schedule(topi.arm_cpu.schedule_depthwise_conv2d_nhwc), + wrap_compute_conv2d(topi.arm_cpu.depthwise_conv2d_nhwc_dsp), + wrap_topi_schedule(topi.arm_cpu.schedule_depthwise_conv2d_nhwc_dsp), name="depthwise_conv2d_nhwc_dsp.arm_cpu", ) diff --git a/python/tvm/topi/arm_cpu/mprofile/dsp/depthwise_conv2d.py b/python/tvm/topi/arm_cpu/mprofile/dsp/depthwise_conv2d.py index c023993e91d1..3fe67507a41b 100644 --- a/python/tvm/topi/arm_cpu/mprofile/dsp/depthwise_conv2d.py +++ b/python/tvm/topi/arm_cpu/mprofile/dsp/depthwise_conv2d.py @@ -14,16 +14,22 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
+"""ARM Cortex-M DSP schedule for depthwise_conv2d""" -"""Direct implementation of conv2d.""" +import random +import string from tvm import autotvm from tvm.autotvm.task import deserialize_args from tvm import te -from tvm.topi.utils import simplify, traverse_inline +from tvm.topi.utils import traverse_inline, get_const_tuple from tvm.topi.nn.pad import pad -from tvm.topi.nn.utils import get_pad_tuple -from tvm.tir.expr import Mul +from tvm import tir + +from .micro_kernel.quad_channel_convolve import ( + intrin_quad_channel_convolve, + quad_channel_convolve_impl, +) # For depthwise_conv2d, kernels are normally given in HWOI format, # which when input_channels = output channels, we will call HWC. @@ -71,24 +77,27 @@ def _rearrange_kernel(kernel): # Kernel must be HWC format. - K_H, K_W, C, _ = get_const_tuple(kernel.shape) - assert C % 4 == 0 + kernel_h, kernel_w, channels, _ = get_const_tuple(kernel.shape) + assert channels % 4 == 0 - # TODO remove this restriction - assert (K_W * K_H) % 2 == 1 + # This restriction could be removed by only using tir.if_then_else to add padding + # zeros if (kernel_w * kernel_h) % 2 == 1, and filling completely otherwise. + assert (kernel_w * kernel_h) % 2 == 1 def fcompute(c_o, pos, c_i): channel = (2 * (pos % 2)) + (c_i % 2) + (4 * c_o) true_pos_index = 2 * (pos // 2) + (c_i // 2) return tir.if_then_else( - true_pos_index < (K_H * K_W), - kernel[true_pos_index // K_W, true_pos_index % K_W, channel, 0], + true_pos_index < (kernel_h * kernel_w), + kernel[true_pos_index // kernel_w, true_pos_index % kernel_w, channel, 0], tir.const(0, "int8"), ) return te.compute( - (C // 4, K_H * K_W + 1, 4), lambda co, pos, ci: fcompute(co, pos, ci), name="packed_kernel" + (channels // 4, kernel_h * kernel_w + 1, 4), + fcompute, + name="packed_kernel", ) @@ -101,7 +110,7 @@ def depthwise_conv2d_nhwc_dsp(*args, **kwargs): cfg = autotvm.get_config() args = [cfg] + args assert layout == "NHWC" - conv = depthwise_conv2d_nhwc_dsp_compute(*args) + conv = depthwise_conv2d_nhwc_dsp_compute(args) sched = depthwise_conv2d_nhwc_dsp_schedule(cfg, [data, kernel, conv]) return sched, [data, kernel, conv] @@ -118,40 +127,41 @@ def depthwise_conv2d_nhwc_dsp_compute(cfg, data, kernel, strides, padding, dilat assert isinstance(dilation, int) or len(dilation) == 2 if isinstance(strides, int): - STRIDE_H = STRIDE_W = strides + stride_h = stride_w = strides else: - STRIDE_H, STRIDE_W = strides + stride_h, stride_w = strides # We do not support dilation currently. It would be possible, but it would require # modifying the way the kernel is packed. Gnarly. if isinstance(dilation, int): - DILATION_H = DILATION_W = dilation + dilation_h = dilation_w = dilation else: - DILATION_H, DILATION_H = dilation - assert DILATION_H == DILATION_H == 1 + dilation_h, dilation_w = dilation + assert dilation_h == dilation_w == 1 - B, H, W, C = data.shape - K_H, K_W, _, _ = kernel.shape + batch_size, height, width, channels = data.shape + kernel_h, kernel_w, _, _ = kernel.shape # We require that the number of channels be divisible by 4. This restriction could # be removed with strip mining if people cared. - assert C % 4 == 0 + assert channels % 4 == 0 # We don't support different numbers of input and output channels. - assert C == kernel.shape[2] + assert channels == kernel.shape[2] assert kernel.shape[3] == 1 - # The int16 case could also be optimized here, but would require writing a whole new - # micro kernel. Probably not worth it. 
- assert out_dtype == "int8" + # We take in int8 as our dtype, but we spit out int32. This is because we cannot + # round until we compute activations. + assert out_dtype == "int32" # This can pretty easily be generalized in the future. Likely worth doing, and this # function was written to make doing so easy. Should only require adding more calls # to QUAD_CHANNEL_REARRANGE_SUM. - assert K_W == K_H == 3 + assert kernel_w == kernel_h == 3 # We do not currently support custom padding. Would be pretty easy to implement. - assert padding == "SAME" or padding == "VALID" + # assert padding == "SAME" or padding == "VALID" + padding = "SAME" # Padding the data requires COPYING THE ENTIRE INPUT TENSOR, which # is slow and bad. We should really implement a strip mining @@ -159,45 +169,49 @@ def depthwise_conv2d_nhwc_dsp_compute(cfg, data, kernel, strides, padding, dilat if padding == "SAME": # This assumption makes the logic easier. Could be removed with work. - assert H % STRIDE_H == W % STRIDE_W == 0 + assert height % stride_h == width % stride_w == 0 - OUT_H = H // STRIDE_H - OUT_W = W // STRIDE_W + output_h = height // stride_h + output_w = width // stride_w - # Padding this way is weird and quirky, and we only do it to match TFLite. - pad_top = 1 if STRIDE_H == 1 else 0 - pad_left = 1 if STRIDE_W == 1 else 0 + # Note - this padding behavior is DIFFERENT from Tensorflow, which pads the top left if + # stride > 1. Need to investigate and decide which behavior we want. + pad_down = 1 if stride_h == 1 else 0 + pad_right = 1 if stride_w == 1 else 0 - data_padded = pad( - data, [0, pad_top, pad_left, 0], [0, K_H // 2, K_W // 2, 0], name="padded_data" + padded_data = pad( + data, + [0, kernel_h // 2, kernel_w // 2, 0], + [0, pad_down, pad_right, 0], + name="padded_data", ) - elif PADDING_STRATEGY == "VALID": - assert H > K_H and W > K_W - OUT_H = (H - K_H) // STRIDE_H + 1 - OUT_W = (W - K_W) // STRIDE_W + 1 - data_padded = data + elif padding == "VALID": + assert height > kernel_h and width > kernel_w + output_h = (height - kernel_h) // stride_h + 1 + output_w = (width - kernel_w) // stride_w + 1 + padded_data = data else: raise RuntimeError() - _, P_H, P_W, _ = data_padded.shape + _, padded_h, padded_w, _ = padded_data.shape packed_kernel = _rearrange_kernel(kernel) - kh_i = te.reduce_axis((0, K_H), name="kh_i") - kw_i = te.reduce_axis((0, K_W), name="kw_i") + kh_i = te.reduce_axis((0, kernel_h), name="kh_i") + kw_i = te.reduce_axis((0, kernel_w), name="kw_i") return te.compute( - (B, OUT_H, OUT_W, C), + (batch_size, output_h, output_w, channels), lambda h, i, j, k: te.sum( - DATA_PAD[h, (i * STRIDE_H) + kh_i, (j * STRIDE_W) + kw_i, k] - * PACKED_KER[ + padded_data[h, (i * stride_h) + kh_i, (j * stride_w) + kw_i, k].astype("int32") + * packed_kernel[ k // 4, (2 * ((3 * kh_i + kw_i) // 2)) + ((k % 4) // 2), (2 * ((kh_i + kw_i) % 2)) + (k % 2), - ], + ].astype("int32"), axis=(kh_i, kw_i), ), name="depthwise_conv2d", - tag=f"depthwise_conv2d_nhwc_{P_H}_{P_W}_dsp", + tag=f"depthwise_conv2d_nhwc_{padded_h}_{padded_w}_dsp", ) @@ -212,12 +226,12 @@ def _callback(op): # extract tensors output = op.output(0) - padded_data = conv_out.op.input_tensors[0] - packed_kernel = conv_out.op.input_tensors[1] + padded_data = output.op.input_tensors[0] + packed_kernel = output.op.input_tensors[1] kernel = packed_kernel.op.input_tensors[0] - B, P_H, P_W, C = padded_data.shape - K_H, K_W, _, _ = kernel.shape + _, _, padded_w, channels = padded_data.shape + kernel_h, kernel_w, _, _ = kernel.shape suffix = 
"".join(random.choices(string.ascii_uppercase, k=8)) b_ax, y_ax, x_ax, c_ax = schedule[output].op.axis @@ -225,154 +239,15 @@ def _callback(op): c_ax_o, c_ax_i = schedule[output].split(c_ax, factor=4) schedule[output].reorder(b_ax, c_ax_o, y_ax, x_ax, ky_ax, kx_ax, c_ax_i) - quad_channel_convolve = intrin_quad_channel_convolve_3x3(P_H, P_W, C, K_H, K_W, suffix) - s[CONVOLVED].tensorize(ky_ax, gemv) - sched[output].pragma( - b_ax, "import_c", quad_channel_convolve_3x3_impl(P_H, P_W, C, K_H, K_W, suffix) + quad_channel_convolve = intrin_quad_channel_convolve( + padded_w, channels, kernel_h, kernel_w, suffix ) - - traverse_inline(sched, outs[-1].op, _callback) - return sched - - -def intrin_quad_channel_convolve_3x3(P_H, P_W, C, K_H, K_W, suffix): - a = te.placeholder((K_H, K_W, 4), name="a", dtype="int8") - b = te.placeholder((K_H * K_W + 1, 4), name="b", dtype="int8") - kh_i = te.reduce_axis((0, K_H), name="kh_i") - kw_i = te.reduce_axis((0, K_W), name="kw_i") - - c = te.compute( - (4,), - lambda k: te.sum( - a[kh_i, kw_i, k] - * b[ - (2 * ((3 * kh_i + kw_i) // 2)) + ((k % 4) // 2), - (2 * ((kh_i + kw_i) % 2)) + (k % 2), - ], - axis=(kh_i, kw_i), - ), - name="c", - ) - - Ab = tvm.tir.decl_buffer(a.shape, a.dtype, name="A", offset_factor=1, strides=[P_W * C, C, 1]) - Bb = tvm.tir.decl_buffer(b.shape, b.dtype, name="B", offset_factor=1, strides=[4, 1]) - Cb = tvm.tir.decl_buffer(c.shape, c.dtype, name="C", offset_factor=1, strides=[1]) - - def intrin_func(ins, outs): - ib = tvm.tir.ir_builder.create() - aa, bb = ins - cc = outs[0] - ib.emit( - tvm.tir.call_extern( - "int32", - f"kernel_convolve_noedge_{P_H}_{P_W}_{C}_{K_H}_{K_W}_{suffix}", - cc.access_ptr("w"), - aa.access_ptr("r"), - bb.access_ptr("r"), - ) + schedule[output].tensorize(ky_ax, quad_channel_convolve) + schedule[output].pragma( + b_ax, + "import_c", + quad_channel_convolve_impl(padded_w, channels, kernel_h, kernel_w, suffix), ) - return ib.get() - - return te.decl_tensor_intrin(c.op, intrin_func, binds={a: Ab, b: Bb, c: Cb}) - - -def quad_channel_convolve_3x3_impl(P_H, P_W, C, K_H, K_W, suffix): - return ( - textwrap.dedent( - f""" - - #include - - // __SXTB16(_ROR(X, Y)) is combined into one assembly instruction - - #define QUAD_CHANNEL_TENSOR_REARRANGE_SUM_DSP( \ - arranged_kernel, \ - tensor_v0_c3210, tensor_v1_c3210, \ - sum0, sum1, sum2, sum3) {{ \ - \ - uint32_t tensor_v0_c20 = __SXTB16(tensor_v0_c3210); \ - uint32_t tensor_v0_c31 = __SXTB16(__ROR(tensor_v0_c3210, 8)); \ - uint32_t tensor_v1_c20 = __SXTB16(tensor_v1_c3210); \ - uint32_t tensor_v1_c31 = __SXTB16(__ROR(tensor_v1_c3210, 8)); \ - \ - uint32_t kernel_v1c1_v1c0_v0c1_v0c0 = *arranged_kernel++; \ - uint32_t kernel_v1c3_v1c2_v0c3_v0c2 = *arranged_kernel++; \ - \ - uint32_t kernel_v10_c0 = __SXTB16(kernel_v1c1_v1c0_v0c1_v0c0); \ - uint32_t kernel_v10_c1 = __SXTB16(__ROR(kernel_v1c1_v1c0_v0c1_v0c0, 8)); \ - uint32_t kernel_v10_c2 = __SXTB16(kernel_v1c3_v1c2_v0c3_v0c2); \ - uint32_t kernel_v10_c3 = __SXTB16(__ROR(kernel_v1c3_v1c2_v0c3_v0c2, 8)); \ - \ - uint32_t tensor_v10_c0 = __PKHBT(tensor_v0_c20, tensor_v1_c20, 16); \ - uint32_t tensor_v10_c1 = __PKHBT(tensor_v0_c31, tensor_v1_c31, 16); \ - uint32_t tensor_v10_c2 = __PKHTB(tensor_v1_c20, tensor_v0_c20, 16); \ - uint32_t tensor_v10_c3 = __PKHTB(tensor_v1_c31, tensor_v0_c31, 16); \ - \ - sum_c0 = __SMLAD(tensor_v10_c0, kernel_v10_c0, sum_c0); \ - sum_c1 = __SMLAD(tensor_v10_c1, kernel_v10_c1, sum_c1); \ - sum_c2 = __SMLAD(tensor_v10_c2, kernel_v10_c2, sum_c2); \ - sum_c3 = __SMLAD(tensor_v10_c3, kernel_v10_c3, 
sum_c3); \ - }} - - - /* Here, we want to take the LOWER BYTES of 32 bit numbers v3 - v0 - * and concatenate them as "v3 || v2 || v1 || v0". In C++, this is: - - return ((sum_c0 & 0x000000FF)) + - ((sum_c1 & 0x000000FF) << 8) + - ((sum_c2 & 0x000000FF) << 16) + - ((sum_c3 & 0x000000FF) << 24); - - * Naively, this takes 4x ands, 3x adds, 3x shifts. With optimization flags, - * clang optimizes this down to eight operations: - - mov r12, #255 - and r0, r0, #255 - orr r12, r12, #65280 - and r1, r12, r1, lsl #8 - orr r0, r1, r0 - and r1, r2, #255 - orr r0, r0, r1, lsl #16 - orr r0, r0, r3, lsl #24 - - * But being clever engineers, we can do it in four instructions. I think, - * but have been unable to prove, that fewer is impossible. */ - - #define WRITE_QUAD_BYTE_JOIN_DSP(out, v3, v2, v1, v0) {{ \ - uint32_t v3_00_v1_00 = PKHBT(v1 << 8, v3, 24); \ - uint32_t gg_v2_gg_v0 = PKHBT(v0, v2, 16); \ - out[0] = UXTAB16(v3_00_v1_00, gg_v2_gg_v0); \ - }} - - /* We do four channels at once to get this speed boost. */ - extern "C" int kernel_convolve_noedge_{P_H}_{P_W}_{C}_{K_H}_{K_W}_{suffix}( - uint32_t *out, - uint32_t *tensor, - uint32_t *packed_kernel) {{ - - uint32_t sum_c0 = 0; - uint32_t sum_c1 = 0; - uint32_t sum_c2 = 0; - uint32_t sum_c3 = 0; - - QUAD_CHANNEL_TENSOR_REARRANGE_SUM( - packed_kernel, *tensor, *(tensor + {C // 4}), - sum_c0, sum_c1, sum_c2, sum_c3) - QUAD_CHANNEL_TENSOR_REARRANGE_SUM( - packed_kernel, *(tensor + {(2) * C // 4}), *(tensor + {P_W * (C // 4)}), - sum_c0, sum_c1, sum_c2, sum_c3) - QUAD_CHANNEL_TENSOR_REARRANGE_SUM( - packed_kernel, *(tensor + {(P_W + 1) * (C // 4)}), *(tensor + {(P_W + 2) * (C // 4)}), - sum_c0, sum_c1, sum_c2, sum_c3) - QUAD_CHANNEL_TENSOR_REARRANGE_SUM( - packed_kernel, *(tensor + {(2 * P_W) * (C // 4)}), *(tensor + {(2 * P_W + 1) * (C // 4)}), - sum_c0, sum_c1, sum_c2, sum_c3) - QUAD_CHANNEL_TENSOR_REARRANGE_SUM( - packed_kernel, *(tensor + {(2 * P_W + 2) * (C // 4)}), 0, - sum_c0, sum_c1, sum_c2, sum_c3) - - WRITE_QUAD_BYTE_JOIN_DSP(out, sum_c3, sum_c2, sum_c1, sum_c0); - }} - """ - ), - ) + + traverse_inline(schedule, outs[-1].op, _callback) + return schedule diff --git a/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/quad_channel_convolve.py b/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/quad_channel_convolve.py new file mode 100644 index 000000000000..70b166aed5b2 --- /dev/null +++ b/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/quad_channel_convolve.py @@ -0,0 +1,178 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""This is a special intrinsic used for depthwise convolution using Cortex-M DSP instructions +(v7e-m). It takes as inputs an int8 HWC data tensor and an int8 CHWc kernel. 
This intrinsic "lays" +the kernel on top of the data tensors tarting from a given pointer, performs signed sixteen-bit +multiplies on each pair of values, and sums all the products in an int32 accumlator. This process is +repeated four times giving four int32 outputs - one per channel.""" + +import textwrap + +from tvm import te, tir + +def intrin_quad_channel_convolve(tensor_w, channels, kernel_h, kernel_w, suffix): + data_slice = te.placeholder((kernel_h, kernel_w, 4), name="a", dtype="int8") + + if kernel_h * kernel_w % 2 == 1: + kernel_length = kernel_h * kernel_w + 1 + else: + kernel_length = kernel_h * kernel_w + kernel_slice = te.placeholder((kernel_length, 4), name="b", dtype="int8") + + kh_i = te.reduce_axis((0, kernel_h), name="kh_i") + kw_i = te.reduce_axis((0, kernel_w), name="kw_i") + + output_slice = te.compute( + (4,), + lambda k: te.sum( + data_slice[kh_i, kw_i, k].astype("int32") + * kernel_slice[ + (2 * ((3 * kh_i + kw_i) // 2)) + ((k % 4) // 2), + (2 * ((kh_i + kw_i) % 2)) + (k % 2), + ].astype("int32"), + axis=(kh_i, kw_i), + ), + name="c", + ) + + data_buf = tir.decl_buffer( + data_slice.shape, + data_slice.dtype, + name="data", + offset_factor=1, + strides=[tensor_w * channels, channels, 1], + ) + kernel_buf = tir.decl_buffer( + kernel_slice.shape, kernel_slice.dtype, name="kernel", offset_factor=1, strides=[4, 1] + ) + output_buf = tir.decl_buffer( + output_slice.shape, output_slice.dtype, name="output", offset_factor=1, strides=[1] + ) + + def intrin_func(ins, outs): + builder = tir.ir_builder.create() + builder.emit( + tir.call_extern( + "int32", + f"kernel_convolve_{tensor_w}_{channels}_{kernel_h}_{kernel_w}_{suffix}", + outs[0].access_ptr("w"), + ins[0].access_ptr("r"), + ins[1].access_ptr("r"), + ) + ) + return builder.get() + + return te.decl_tensor_intrin( + output_slice.op, + intrin_func, + binds={data_slice: data_buf, kernel_slice: kernel_buf, output_slice: output_buf}, + ) + + +def quad_channel_convolve_impl(tensor_w, channels, kernel_h, kernel_w, suffix): + # intrin_quad_channel_convolve supports any kernel size, but this function only supports + # 3x3 kernels (though this could be fixed with work). 
+ assert kernel_h == kernel_w == 3 + + return textwrap.dedent( + ( + f""" + #include + #include + + // __SXTB16(_ROR(X, Y)) is combined into one assembly instruction + + #define TVMGEN_QUAD_CHANNEL_REARRANGE_SUM_DSP( \ + arranged_kernel, \ + tensor_v0_c3210, tensor_v1_c3210, \ + sum0, sum1, sum2, sum3) {{ \ + \ + uint32_t tensor_v0_c20 = __SXTB16(tensor_v0_c3210); \ + uint32_t tensor_v0_c31 = __SXTB16(__ROR(tensor_v0_c3210, 8)); \ + uint32_t tensor_v1_c20 = __SXTB16(tensor_v1_c3210); \ + uint32_t tensor_v1_c31 = __SXTB16(__ROR(tensor_v1_c3210, 8)); \ + \ + uint32_t kernel_v1c1_v1c0_v0c1_v0c0 = *arranged_kernel++; \ + uint32_t kernel_v1c3_v1c2_v0c3_v0c2 = *arranged_kernel++; \ + \ + uint32_t kernel_v10_c0 = __SXTB16(kernel_v1c1_v1c0_v0c1_v0c0); \ + uint32_t kernel_v10_c1 = __SXTB16(__ROR(kernel_v1c1_v1c0_v0c1_v0c0, 8)); \ + uint32_t kernel_v10_c2 = __SXTB16(kernel_v1c3_v1c2_v0c3_v0c2); \ + uint32_t kernel_v10_c3 = __SXTB16(__ROR(kernel_v1c3_v1c2_v0c3_v0c2, 8)); \ + \ + uint32_t tensor_v10_c0 = __PKHBT(tensor_v0_c20, tensor_v1_c20, 16); \ + uint32_t tensor_v10_c1 = __PKHBT(tensor_v0_c31, tensor_v1_c31, 16); \ + uint32_t tensor_v10_c2 = __PKHTB(tensor_v1_c20, tensor_v0_c20, 16); \ + uint32_t tensor_v10_c3 = __PKHTB(tensor_v1_c31, tensor_v0_c31, 16); \ + \ + sum_c0 = __SMLAD(tensor_v10_c0, kernel_v10_c0, sum_c0); \ + sum_c1 = __SMLAD(tensor_v10_c1, kernel_v10_c1, sum_c1); \ + sum_c2 = __SMLAD(tensor_v10_c2, kernel_v10_c2, sum_c2); \ + sum_c3 = __SMLAD(tensor_v10_c3, kernel_v10_c3, sum_c3); \ + }} + + /* We do four channels at once to get this speed boost. */ + #ifdef __cplusplus + extern "C" + #endif + int32_t kernel_convolve_{tensor_w}_{channels}_{kernel_h}_{kernel_w}_{suffix}( + uint32_t *out, + uint32_t *tensor, + uint32_t *packed_kernel) {{ + + uint32_t sum_c0 = 0; + uint32_t sum_c1 = 0; + uint32_t sum_c2 = 0; + uint32_t sum_c3 = 0; + + TVMGEN_QUAD_CHANNEL_REARRANGE_SUM_DSP( + packed_kernel, + *tensor, + *(tensor + {channels // 4}), + sum_c0, sum_c1, sum_c2, sum_c3) + TVMGEN_QUAD_CHANNEL_REARRANGE_SUM_DSP( + packed_kernel, + *(tensor + {(2) * channels // 4}), + *(tensor + {tensor_w * (channels // 4)}), + sum_c0, sum_c1, sum_c2, sum_c3) + TVMGEN_QUAD_CHANNEL_REARRANGE_SUM_DSP( + packed_kernel, + *(tensor + {(tensor_w + 1) * (channels // 4)}), + *(tensor + {(tensor_w + 2) * (channels // 4)}), + sum_c0, sum_c1, sum_c2, sum_c3) + TVMGEN_QUAD_CHANNEL_REARRANGE_SUM_DSP( + packed_kernel, + *(tensor + {(2 * tensor_w) * (channels // 4)}), + *(tensor + {(2 * tensor_w + 1) * (channels // 4)}), + sum_c0, sum_c1, sum_c2, sum_c3) + TVMGEN_QUAD_CHANNEL_REARRANGE_SUM_DSP( + packed_kernel, + *(tensor + {(2 * tensor_w + 2) * (channels // 4)}), + 0, + sum_c0, sum_c1, sum_c2, sum_c3) + + out[0] = sum_c0; + out[1] = sum_c1; + out[2] = sum_c2; + out[3] = sum_c3; + return 0; + }} + + #undef TVMGEN_QUAD_CHANNEL_REARRANGE_SUM_DSP + """ + ) + ) diff --git a/tests/python/relay/strategy/arm_cpu/test_depthwise_conv2d.py b/tests/python/relay/strategy/arm_cpu/test_depthwise_conv2d.py index ee0d51c321f7..c3f3660e53fc 100644 --- a/tests/python/relay/strategy/arm_cpu/test_depthwise_conv2d.py +++ b/tests/python/relay/strategy/arm_cpu/test_depthwise_conv2d.py @@ -147,5 +147,29 @@ class TestDepthwiseConv2d_NHWC_HWOI(BasicDepthwiseConv2dTests): schedule_name = tvm.testing.parameter("depthwise_conv2d_nhwc.generic") +class TestDepthwiseConv2d_NHWC_HWOI_DSP(BasicDepthwiseConv2dTests): + """This test is for depthwise_conv2d_nhwc.generic schedule.""" + + data_shape, kernel_size, num_filter, strides, padding, dilation = 
tvm.testing.parameters( + # Depthwise_conv2d parameters from MobileNetV1 0.25x + ((1, 48, 48, 8), (3, 3), 8, (1, 1), 1, 1), + ((1, 48, 48, 16), (3, 3), 16, (2, 2), 1, 1), + ((1, 24, 24, 32), (3, 3), 32, (1, 1), 1, 1), + ((1, 24, 24, 32), (3, 3), 32, (2, 2), 1, 1), + ((1, 12, 12, 64), (3, 3), 64, (1, 1), 1, 1), + ((1, 12, 12, 64), (3, 3), 64, (2, 2), 1, 1), + ((1, 6, 6, 128), (3, 3), 128, (1, 1), 1, 1), + ((1, 6, 6, 128), (3, 3), 128, (2, 2), 1, 1), + ((1, 3, 3, 256), (3, 3), 256, (1, 1), 1, 1), + + # Asymmetric height and width + ((1, 25, 5, 64), (3, 3), 64, (1, 1), 1, 1), + ) + data_layout = tvm.testing.parameter("NHWC") + dtype = tvm.testing.parameter("int8") + kernel_layout = tvm.testing.parameter("HWOI") + schedule_name = tvm.testing.parameter("depthwise_conv2d_nhwc_dsp.arm_cpu") + + if __name__ == "__main__": tvm.testing.main() From fc7f82575b4c7e5995999d4a0f5208bef8edd99b Mon Sep 17 00:00:00 2001 From: Gavin Uberti Date: Wed, 31 Aug 2022 11:33:32 -0700 Subject: [PATCH 3/6] Add support for fully custom padding --- python/tvm/relay/op/strategy/arm_cpu.py | 2 +- .../arm_cpu/mprofile/dsp/depthwise_conv2d.py | 22 ++++++++++++++----- .../dsp/micro_kernel/quad_channel_convolve.py | 6 +++-- .../strategy/arm_cpu/test_depthwise_conv2d.py | 13 ++++++----- 4 files changed, 29 insertions(+), 14 deletions(-) diff --git a/python/tvm/relay/op/strategy/arm_cpu.py b/python/tvm/relay/op/strategy/arm_cpu.py index 82f3a2410e0f..6d19982c995f 100644 --- a/python/tvm/relay/op/strategy/arm_cpu.py +++ b/python/tvm/relay/op/strategy/arm_cpu.py @@ -247,7 +247,7 @@ def conv2d_strategy_arm_cpu(attrs, inputs, out_type, target): target.features.has_dsp and kernel.shape[0] == kernel.shape[1] == 3 and dilation_w == dilation_h == 1 - and kernel.shape[3] == 1 # channel_multiplier == 1 + and kernel.shape[3] == 1 # channel_multiplier == 1 and data.dtype == "int8" and data.shape[3] % 4 == 0 ): diff --git a/python/tvm/topi/arm_cpu/mprofile/dsp/depthwise_conv2d.py b/python/tvm/topi/arm_cpu/mprofile/dsp/depthwise_conv2d.py index 3fe67507a41b..1d5d0efc27fc 100644 --- a/python/tvm/topi/arm_cpu/mprofile/dsp/depthwise_conv2d.py +++ b/python/tvm/topi/arm_cpu/mprofile/dsp/depthwise_conv2d.py @@ -110,7 +110,7 @@ def depthwise_conv2d_nhwc_dsp(*args, **kwargs): cfg = autotvm.get_config() args = [cfg] + args assert layout == "NHWC" - conv = depthwise_conv2d_nhwc_dsp_compute(args) + conv = depthwise_conv2d_nhwc_dsp_compute(*args) sched = depthwise_conv2d_nhwc_dsp_schedule(cfg, [data, kernel, conv]) return sched, [data, kernel, conv] @@ -159,10 +159,6 @@ def depthwise_conv2d_nhwc_dsp_compute(cfg, data, kernel, strides, padding, dilat # to QUAD_CHANNEL_REARRANGE_SUM. assert kernel_w == kernel_h == 3 - # We do not currently support custom padding. Would be pretty easy to implement. - # assert padding == "SAME" or padding == "VALID" - padding = "SAME" - # Padding the data requires COPYING THE ENTIRE INPUT TENSOR, which # is slow and bad. We should really implement a strip mining # routine to avoid this, but TVM has terrible support for that. 
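The hunk that follows teaches the compute function to accept explicit padding tuples alongside "SAME" and "VALID". A minimal sketch of the normalization and output-size arithmetic it adds (the helper name is invented here for illustration; the patch inlines this logic):

    def normalize_padding(padding):
        # Either ((pad_up, pad_down), (pad_left, pad_right)) or a flat
        # (pad_up, pad_left, pad_down, pad_right) tuple is accepted.
        if len(padding) == 2:
            (pad_up, pad_down), (pad_left, pad_right) = padding
        else:
            pad_up, pad_left, pad_down, pad_right = padding
        return pad_up, pad_left, pad_down, pad_right

    # e.g. for a 48x48 input, 3x3 kernel, stride 2, padding (1, 1, 0, 0):
    # output_h = (48 - 3 + 1 + 0) // 2 + 1 = 24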
@@ -192,6 +188,22 @@ def depthwise_conv2d_nhwc_dsp_compute(cfg, data, kernel, strides, padding, dilat output_w = (width - kernel_w) // stride_w + 1 padded_data = data + elif isinstance(padding, tuple): + if len(padding) == 2: + pad_up, pad_down = padding[0] + pad_left, pad_right = padding[1] + else: + pad_up, pad_left, pad_down, pad_right = padding + + output_h = (height - kernel_h + pad_up + pad_down) // stride_h + 1 + output_w = (width - kernel_w + pad_left + pad_right) // stride_w + 1 + padded_data = pad( + data, + [0, pad_up, pad_left, 0], + [0, pad_down, pad_right, 0], + name="padded_data", + ) + else: raise RuntimeError() _, padded_h, padded_w, _ = padded_data.shape diff --git a/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/quad_channel_convolve.py b/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/quad_channel_convolve.py index 70b166aed5b2..357d677a7a6d 100644 --- a/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/quad_channel_convolve.py +++ b/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/quad_channel_convolve.py @@ -24,7 +24,9 @@ from tvm import te, tir + def intrin_quad_channel_convolve(tensor_w, channels, kernel_h, kernel_w, suffix): + """Defines a v7e-m DSP-accelerated four-channel convolution.""" data_slice = te.placeholder((kernel_h, kernel_w, 4), name="a", dtype="int8") if kernel_h * kernel_w % 2 == 1: @@ -84,8 +86,8 @@ def intrin_func(ins, outs): def quad_channel_convolve_impl(tensor_w, channels, kernel_h, kernel_w, suffix): - # intrin_quad_channel_convolve supports any kernel size, but this function only supports - # 3x3 kernels (though this could be fixed with work). + """Emits C code for quad_channel_convolve. Note that while intrin_quad_channel_convolve supports + any kernel size, this function only supports 3x3 kernels (though this could be fixed with work).""" assert kernel_h == kernel_w == 3 return textwrap.dedent( diff --git a/tests/python/relay/strategy/arm_cpu/test_depthwise_conv2d.py b/tests/python/relay/strategy/arm_cpu/test_depthwise_conv2d.py index c3f3660e53fc..0dd715e5fb93 100644 --- a/tests/python/relay/strategy/arm_cpu/test_depthwise_conv2d.py +++ b/tests/python/relay/strategy/arm_cpu/test_depthwise_conv2d.py @@ -151,17 +151,18 @@ class TestDepthwiseConv2d_NHWC_HWOI_DSP(BasicDepthwiseConv2dTests): """This test is for depthwise_conv2d_nhwc.generic schedule.""" data_shape, kernel_size, num_filter, strides, padding, dilation = tvm.testing.parameters( + # The LLVM implementation doesn't support "SAME" and "VALID" padding, + # so padding must be explicitly specified. 
# Depthwise_conv2d parameters from MobileNetV1 0.25x - ((1, 48, 48, 8), (3, 3), 8, (1, 1), 1, 1), - ((1, 48, 48, 16), (3, 3), 16, (2, 2), 1, 1), + ((1, 48, 48, 8), (3, 3), 8, (1, 1), 1, 1), + ((1, 48, 48, 16), (3, 3), 16, (2, 2), (1, 1, 0, 0), 1), ((1, 24, 24, 32), (3, 3), 32, (1, 1), 1, 1), - ((1, 24, 24, 32), (3, 3), 32, (2, 2), 1, 1), + ((1, 24, 24, 32), (3, 3), 32, (2, 2), (1, 1, 0, 0), 1), ((1, 12, 12, 64), (3, 3), 64, (1, 1), 1, 1), - ((1, 12, 12, 64), (3, 3), 64, (2, 2), 1, 1), + ((1, 12, 12, 64), (3, 3), 64, (2, 2), (1, 1, 0, 0), 1), ((1, 6, 6, 128), (3, 3), 128, (1, 1), 1, 1), - ((1, 6, 6, 128), (3, 3), 128, (2, 2), 1, 1), + ((1, 6, 6, 128), (3, 3), 128, (2, 2), (1, 1, 0, 0), 1), ((1, 3, 3, 256), (3, 3), 256, (1, 1), 1, 1), - # Asymmetric height and width ((1, 25, 5, 64), (3, 3), 64, (1, 1), 1, 1), ) From 437dac43b6954a749a34637508e4dbfb6e1869d1 Mon Sep 17 00:00:00 2001 From: Gavin Uberti Date: Wed, 31 Aug 2022 21:58:59 -0700 Subject: [PATCH 4/6] Fix pylint --- .../arm_cpu/mprofile/dsp/depthwise_conv2d.py | 27 +++---------------- .../dsp/micro_kernel/quad_channel_convolve.py | 2 +- 2 files changed, 4 insertions(+), 25 deletions(-) diff --git a/python/tvm/topi/arm_cpu/mprofile/dsp/depthwise_conv2d.py b/python/tvm/topi/arm_cpu/mprofile/dsp/depthwise_conv2d.py index 1d5d0efc27fc..ede822da76b3 100644 --- a/python/tvm/topi/arm_cpu/mprofile/dsp/depthwise_conv2d.py +++ b/python/tvm/topi/arm_cpu/mprofile/dsp/depthwise_conv2d.py @@ -19,8 +19,6 @@ import random import string -from tvm import autotvm -from tvm.autotvm.task import deserialize_args from tvm import te from tvm.topi.utils import traverse_inline, get_const_tuple from tvm.topi.nn.pad import pad @@ -101,28 +99,9 @@ def fcompute(c_o, pos, c_i): ) -def depthwise_conv2d_nhwc_dsp(*args, **kwargs): - """Defines the v7e-m DSP instructions of depthwise_conv2d.""" - assert not kwargs, "Do not support kwargs in template function call" - args = deserialize_args(args) - data, kernel = args[:2] - layout = args[-2] - cfg = autotvm.get_config() - args = [cfg] + args - assert layout == "NHWC" - conv = depthwise_conv2d_nhwc_dsp_compute(*args) - sched = depthwise_conv2d_nhwc_dsp_schedule(cfg, [data, kernel, conv]) - return sched, [data, kernel, conv] - - -depthwise_conv2d_nhwc_dsp.template_key = "dsp" -depthwise_conv2d_nhwc_dsp.default_data_layout = "NHWC" -depthwise_conv2d_nhwc_dsp.default_kernel_layout = "HWOI" - - -def depthwise_conv2d_nhwc_dsp_compute(cfg, data, kernel, strides, padding, dilation, out_dtype): +def depthwise_conv2d_nhwc_dsp_compute(_cfg, data, kernel, strides, padding, dilation, out_dtype): """Compute function for v7e-m DSP instructions of DepthwiseConv2D. 
Has a lot of requirements - for use - not not all apply, the fallback implementation will be used instead.""" + for use - if not all apply, the fallback implementation will be used instead.""" assert isinstance(strides, int) or len(strides) == 2 assert isinstance(dilation, int) or len(dilation) == 2 @@ -227,7 +206,7 @@ def depthwise_conv2d_nhwc_dsp_compute(cfg, data, kernel, strides, padding, dilat ) -def depthwise_conv2d_nhwc_dsp_schedule(cfg, outs): +def depthwise_conv2d_nhwc_dsp_schedule(_cfg, outs): """Schedule function for v7e-m DSP instructions of conv2d.""" schedule = te.create_schedule([x.op for x in outs]) diff --git a/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/quad_channel_convolve.py b/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/quad_channel_convolve.py index 357d677a7a6d..b83c0d4d00c9 100644 --- a/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/quad_channel_convolve.py +++ b/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/quad_channel_convolve.py @@ -87,7 +87,7 @@ def intrin_func(ins, outs): def quad_channel_convolve_impl(tensor_w, channels, kernel_h, kernel_w, suffix): """Emits C code for quad_channel_convolve. Note that while intrin_quad_channel_convolve supports - any kernel size, this function only supports 3x3 kernels (though this could be fixed with work).""" + any kernel size, this function only supports 3x3 kernels (this could be fixed with work).""" assert kernel_h == kernel_w == 3 return textwrap.dedent( From 611778c089433c60369c6817435be48dd8f9f398 Mon Sep 17 00:00:00 2001 From: Gavin Uberti Date: Thu, 1 Sep 2022 10:54:44 -0700 Subject: [PATCH 5/6] Fix comments on PR --- .../arm_cpu/mprofile/dsp/micro_kernel/quad_channel_convolve.py | 2 +- tests/python/relay/strategy/arm_cpu/test_depthwise_conv2d.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/quad_channel_convolve.py b/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/quad_channel_convolve.py index b83c0d4d00c9..4d8536866d47 100644 --- a/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/quad_channel_convolve.py +++ b/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/quad_channel_convolve.py @@ -16,7 +16,7 @@ # under the License. """This is a special intrinsic used for depthwise convolution using Cortex-M DSP instructions (v7e-m). It takes as inputs an int8 HWC data tensor and an int8 CHWc kernel. This intrinsic "lays" -the kernel on top of the data tensors tarting from a given pointer, performs signed sixteen-bit +the kernel on top of the data tensors starting from a given pointer, performs signed sixteen-bit multiplies on each pair of values, and sums all the products in an int32 accumlator. 
This process is repeated four times giving four int32 outputs - one per channel.""" diff --git a/tests/python/relay/strategy/arm_cpu/test_depthwise_conv2d.py b/tests/python/relay/strategy/arm_cpu/test_depthwise_conv2d.py index 0dd715e5fb93..18c5082f2a0c 100644 --- a/tests/python/relay/strategy/arm_cpu/test_depthwise_conv2d.py +++ b/tests/python/relay/strategy/arm_cpu/test_depthwise_conv2d.py @@ -148,7 +148,7 @@ class TestDepthwiseConv2d_NHWC_HWOI(BasicDepthwiseConv2dTests): class TestDepthwiseConv2d_NHWC_HWOI_DSP(BasicDepthwiseConv2dTests): - """This test is for depthwise_conv2d_nhwc.generic schedule.""" + """This test is for depthwise_conv2d_nhwc_dsp.arm_cpu schedule.""" data_shape, kernel_size, num_filter, strides, padding, dilation = tvm.testing.parameters( # The LLVM implementation doesn't support "SAME" and "VALID" padding, From 33fadf8b98be4f87467fb9afa7b472442aee2db6 Mon Sep 17 00:00:00 2001 From: Gavin Uberti Date: Fri, 2 Sep 2022 11:02:10 -0700 Subject: [PATCH 6/6] Address comments from Ashutosh --- python/tvm/relay/op/strategy/arm_cpu.py | 10 +++++----- .../tvm/topi/arm_cpu/mprofile/dsp/depthwise_conv2d.py | 5 +++-- .../mprofile/dsp/micro_kernel/quad_channel_convolve.py | 4 ++-- 3 files changed, 10 insertions(+), 9 deletions(-) diff --git a/python/tvm/relay/op/strategy/arm_cpu.py b/python/tvm/relay/op/strategy/arm_cpu.py index 6d19982c995f..2d9ef99ba8a6 100644 --- a/python/tvm/relay/op/strategy/arm_cpu.py +++ b/python/tvm/relay/op/strategy/arm_cpu.py @@ -237,11 +237,9 @@ def conv2d_strategy_arm_cpu(attrs, inputs, out_type, target): ) # Optimized special case depthwiseConv2D operation. Requires a 3x3 kernel, a - # NHWC layout, a HWOI kernel layout (which we would ideally rearrange), no dilation, - # "SAME" padding, int8 inputs and outputs, the same number of input and output - # channels, and for that channel count to be divisible by 4. - # - # Additional work could remove some of these restrictions. + # NHWC layout, a HWOI kernel layout (which we rearrange), no dilation, int8 inputs, + # int32 output, the same number of input and output channels, and for that channel + # count to be divisible by 4. Additional work could remove these restrictions. elif ( target.features.has_dsp @@ -249,7 +247,9 @@ def conv2d_strategy_arm_cpu(attrs, inputs, out_type, target): and dilation_w == dilation_h == 1 and kernel.shape[3] == 1 # channel_multiplier == 1 and data.dtype == "int8" + and out_type.dtype == "int32" and data.shape[3] % 4 == 0 + and (padding != "SAME" or data.shape[1] % stride_h == data.shape[2] % stride_w == 0) ): strategy.add_implementation( wrap_compute_conv2d(topi.arm_cpu.depthwise_conv2d_nhwc_dsp), diff --git a/python/tvm/topi/arm_cpu/mprofile/dsp/depthwise_conv2d.py b/python/tvm/topi/arm_cpu/mprofile/dsp/depthwise_conv2d.py index ede822da76b3..162bf65a21f9 100644 --- a/python/tvm/topi/arm_cpu/mprofile/dsp/depthwise_conv2d.py +++ b/python/tvm/topi/arm_cpu/mprofile/dsp/depthwise_conv2d.py @@ -149,8 +149,9 @@ def depthwise_conv2d_nhwc_dsp_compute(_cfg, data, kernel, strides, padding, dila output_h = height // stride_h output_w = width // stride_w - # Note - this padding behavior is DIFFERENT from Tensorflow, which pads the top left if - # stride > 1. Need to investigate and decide which behavior we want. + # This padding behavior is consistent with other TVM depthwise_conv2d schedules. However it + # differs from the TensorFlow, which only pads the bottom right if stride > 1. This probably + # brings down accuracy slightly for models imported from TFLite. 
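# (Illustration: for a 48x48 input, 3x3 kernel, and stride 2, the code
# below pads one row/column on the top/left via kernel_h // 2, while
# TensorFlow's "SAME" pads only the bottom/right; the two conventions
# therefore read input windows offset by one pixel from each other.)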
pad_down = 1 if stride_h == 1 else 0 pad_right = 1 if stride_w == 1 else 0 diff --git a/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/quad_channel_convolve.py b/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/quad_channel_convolve.py index 4d8536866d47..960ef8fadc0e 100644 --- a/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/quad_channel_convolve.py +++ b/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/quad_channel_convolve.py @@ -70,7 +70,7 @@ def intrin_func(ins, outs): builder.emit( tir.call_extern( "int32", - f"kernel_convolve_{tensor_w}_{channels}_{kernel_h}_{kernel_w}_{suffix}", + f"kernel_convolve_w{tensor_w}_c{channels}_kh{kernel_h}_kw{kernel_w}_{suffix}", outs[0].access_ptr("w"), ins[0].access_ptr("r"), ins[1].access_ptr("r"), @@ -131,7 +131,7 @@ def quad_channel_convolve_impl(tensor_w, channels, kernel_h, kernel_w, suffix): #ifdef __cplusplus extern "C" #endif - int32_t kernel_convolve_{tensor_w}_{channels}_{kernel_h}_{kernel_w}_{suffix}( + int32_t kernel_convolve_w{tensor_w}_c{channels}_kh{kernel_h}_kw{kernel_w}_{suffix}( uint32_t *out, uint32_t *tensor, uint32_t *packed_kernel) {{
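Taken together, the series enables flows like the following sketch. The shapes are chosen to satisfy the strategy's constraints; the target string is one plausible way to select a DSP-capable Cortex-M core, and the exact build flow may differ by TVM version:

    import tvm
    from tvm import relay

    # int8 NHWC data, HWOI kernel with channel_multiplier == 1, channels
    # divisible by 4, int32 output - the conditions conv2d_strategy_arm_cpu
    # checks before selecting depthwise_conv2d_nhwc_dsp.arm_cpu.
    data = relay.var("data", shape=(1, 48, 48, 8), dtype="int8")
    kernel = relay.var("kernel", shape=(3, 3, 8, 1), dtype="int8")
    conv = relay.nn.conv2d(
        data,
        kernel,
        kernel_size=(3, 3),
        channels=8,
        groups=8,  # groups == channels makes this a depthwise convolution
        padding=(1, 1, 1, 1),
        data_layout="NHWC",
        kernel_layout="HWOI",
        out_dtype="int32",
    )
    mod = tvm.IRModule.from_expr(relay.Function([data, kernel], conv))

    # Cortex-M7 is one v7e-m core for which target.features.has_dsp holds.
    target = tvm.target.Target("c -keys=arm_cpu -mcpu=cortex-m7")
    with tvm.transform.PassContext(opt_level=3):
        lib = relay.build(mod, target=target)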