From 27e0663742c2a671c92892f9bf48c47c07a2f425 Mon Sep 17 00:00:00 2001 From: Gavin Uberti Date: Tue, 30 Aug 2022 06:58:40 -0700 Subject: [PATCH] Code reorganization and testing Functional DWC2D schedule with test Code cleanup and linting Fix padding to match Relay and add tests Fix test cases --- python/tvm/relay/op/strategy/arm_cpu.py | 7 +- .../arm_cpu/mprofile/dsp/depthwise_conv2d.py | 275 +++++------------- .../dsp/micro_kernel/quad_channel_convolve.py | 178 ++++++++++++ .../strategy/arm_cpu/test_depthwise_conv2d.py | 24 ++ 4 files changed, 280 insertions(+), 204 deletions(-) create mode 100644 python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/quad_channel_convolve.py diff --git a/python/tvm/relay/op/strategy/arm_cpu.py b/python/tvm/relay/op/strategy/arm_cpu.py index 15cc76885f53..82f3a2410e0f 100644 --- a/python/tvm/relay/op/strategy/arm_cpu.py +++ b/python/tvm/relay/op/strategy/arm_cpu.py @@ -244,17 +244,16 @@ def conv2d_strategy_arm_cpu(attrs, inputs, out_type, target): # Additional work could remove some of these restrictions. elif ( - isa.has_dsp_support + target.features.has_dsp and kernel.shape[0] == kernel.shape[1] == 3 and dilation_w == dilation_h == 1 and kernel.shape[3] == 1 # channel_multiplier == 1 and data.dtype == "int8" - and padding == "SAME" and data.shape[3] % 4 == 0 ): strategy.add_implementation( - wrap_compute_conv2d(topi.arm_cpu.depthwise_conv2d_nhwc), - wrap_topi_schedule(topi.arm_cpu.schedule_depthwise_conv2d_nhwc), + wrap_compute_conv2d(topi.arm_cpu.depthwise_conv2d_nhwc_dsp), + wrap_topi_schedule(topi.arm_cpu.schedule_depthwise_conv2d_nhwc_dsp), name="depthwise_conv2d_nhwc_dsp.arm_cpu", ) diff --git a/python/tvm/topi/arm_cpu/mprofile/dsp/depthwise_conv2d.py b/python/tvm/topi/arm_cpu/mprofile/dsp/depthwise_conv2d.py index c023993e91d1..3fe67507a41b 100644 --- a/python/tvm/topi/arm_cpu/mprofile/dsp/depthwise_conv2d.py +++ b/python/tvm/topi/arm_cpu/mprofile/dsp/depthwise_conv2d.py @@ -14,16 +14,22 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +"""ARM Cortex-M DSP schedule for depthwise_conv2d""" -"""Direct implementation of conv2d.""" +import random +import string from tvm import autotvm from tvm.autotvm.task import deserialize_args from tvm import te -from tvm.topi.utils import simplify, traverse_inline +from tvm.topi.utils import traverse_inline, get_const_tuple from tvm.topi.nn.pad import pad -from tvm.topi.nn.utils import get_pad_tuple -from tvm.tir.expr import Mul +from tvm import tir + +from .micro_kernel.quad_channel_convolve import ( + intrin_quad_channel_convolve, + quad_channel_convolve_impl, +) # For depthwise_conv2d, kernels are normally given in HWOI format, # which when input_channels = output channels, we will call HWC. @@ -71,24 +77,27 @@ def _rearrange_kernel(kernel): # Kernel must be HWC format. - K_H, K_W, C, _ = get_const_tuple(kernel.shape) - assert C % 4 == 0 + kernel_h, kernel_w, channels, _ = get_const_tuple(kernel.shape) + assert channels % 4 == 0 - # TODO remove this restriction - assert (K_W * K_H) % 2 == 1 + # This restriction could be removed by only using tir.if_then_else to add padding + # zeros if (kernel_w * kernel_h) % 2 == 1, and filling completely otherwise. + assert (kernel_w * kernel_h) % 2 == 1 def fcompute(c_o, pos, c_i): channel = (2 * (pos % 2)) + (c_i % 2) + (4 * c_o) true_pos_index = 2 * (pos // 2) + (c_i // 2) return tir.if_then_else( - true_pos_index < (K_H * K_W), - kernel[true_pos_index // K_W, true_pos_index % K_W, channel, 0], + true_pos_index < (kernel_h * kernel_w), + kernel[true_pos_index // kernel_w, true_pos_index % kernel_w, channel, 0], tir.const(0, "int8"), ) return te.compute( - (C // 4, K_H * K_W + 1, 4), lambda co, pos, ci: fcompute(co, pos, ci), name="packed_kernel" + (channels // 4, kernel_h * kernel_w + 1, 4), + fcompute, + name="packed_kernel", ) @@ -101,7 +110,7 @@ def depthwise_conv2d_nhwc_dsp(*args, **kwargs): cfg = autotvm.get_config() args = [cfg] + args assert layout == "NHWC" - conv = depthwise_conv2d_nhwc_dsp_compute(*args) + conv = depthwise_conv2d_nhwc_dsp_compute(args) sched = depthwise_conv2d_nhwc_dsp_schedule(cfg, [data, kernel, conv]) return sched, [data, kernel, conv] @@ -118,40 +127,41 @@ def depthwise_conv2d_nhwc_dsp_compute(cfg, data, kernel, strides, padding, dilat assert isinstance(dilation, int) or len(dilation) == 2 if isinstance(strides, int): - STRIDE_H = STRIDE_W = strides + stride_h = stride_w = strides else: - STRIDE_H, STRIDE_W = strides + stride_h, stride_w = strides # We do not support dilation currently. It would be possible, but it would require # modifying the way the kernel is packed. Gnarly. if isinstance(dilation, int): - DILATION_H = DILATION_W = dilation + dilation_h = dilation_w = dilation else: - DILATION_H, DILATION_H = dilation - assert DILATION_H == DILATION_H == 1 + dilation_h, dilation_w = dilation + assert dilation_h == dilation_w == 1 - B, H, W, C = data.shape - K_H, K_W, _, _ = kernel.shape + batch_size, height, width, channels = data.shape + kernel_h, kernel_w, _, _ = kernel.shape # We require that the number of channels be divisible by 4. This restriction could # be removed with strip mining if people cared. - assert C % 4 == 0 + assert channels % 4 == 0 # We don't support different numbers of input and output channels. - assert C == kernel.shape[2] + assert channels == kernel.shape[2] assert kernel.shape[3] == 1 - # The int16 case could also be optimized here, but would require writing a whole new - # micro kernel. Probably not worth it. - assert out_dtype == "int8" + # We take in int8 as our dtype, but we spit out int32. This is because we cannot + # round until we compute activations. + assert out_dtype == "int32" # This can pretty easily be generalized in the future. Likely worth doing, and this # function was written to make doing so easy. Should only require adding more calls # to QUAD_CHANNEL_REARRANGE_SUM. - assert K_W == K_H == 3 + assert kernel_w == kernel_h == 3 # We do not currently support custom padding. Would be pretty easy to implement. - assert padding == "SAME" or padding == "VALID" + # assert padding == "SAME" or padding == "VALID" + padding = "SAME" # Padding the data requires COPYING THE ENTIRE INPUT TENSOR, which # is slow and bad. We should really implement a strip mining @@ -159,45 +169,49 @@ def depthwise_conv2d_nhwc_dsp_compute(cfg, data, kernel, strides, padding, dilat if padding == "SAME": # This assumption makes the logic easier. Could be removed with work. - assert H % STRIDE_H == W % STRIDE_W == 0 + assert height % stride_h == width % stride_w == 0 - OUT_H = H // STRIDE_H - OUT_W = W // STRIDE_W + output_h = height // stride_h + output_w = width // stride_w - # Padding this way is weird and quirky, and we only do it to match TFLite. - pad_top = 1 if STRIDE_H == 1 else 0 - pad_left = 1 if STRIDE_W == 1 else 0 + # Note - this padding behavior is DIFFERENT from Tensorflow, which pads the top left if + # stride > 1. Need to investigate and decide which behavior we want. + pad_down = 1 if stride_h == 1 else 0 + pad_right = 1 if stride_w == 1 else 0 - data_padded = pad( - data, [0, pad_top, pad_left, 0], [0, K_H // 2, K_W // 2, 0], name="padded_data" + padded_data = pad( + data, + [0, kernel_h // 2, kernel_w // 2, 0], + [0, pad_down, pad_right, 0], + name="padded_data", ) - elif PADDING_STRATEGY == "VALID": - assert H > K_H and W > K_W - OUT_H = (H - K_H) // STRIDE_H + 1 - OUT_W = (W - K_W) // STRIDE_W + 1 - data_padded = data + elif padding == "VALID": + assert height > kernel_h and width > kernel_w + output_h = (height - kernel_h) // stride_h + 1 + output_w = (width - kernel_w) // stride_w + 1 + padded_data = data else: raise RuntimeError() - _, P_H, P_W, _ = data_padded.shape + _, padded_h, padded_w, _ = padded_data.shape packed_kernel = _rearrange_kernel(kernel) - kh_i = te.reduce_axis((0, K_H), name="kh_i") - kw_i = te.reduce_axis((0, K_W), name="kw_i") + kh_i = te.reduce_axis((0, kernel_h), name="kh_i") + kw_i = te.reduce_axis((0, kernel_w), name="kw_i") return te.compute( - (B, OUT_H, OUT_W, C), + (batch_size, output_h, output_w, channels), lambda h, i, j, k: te.sum( - DATA_PAD[h, (i * STRIDE_H) + kh_i, (j * STRIDE_W) + kw_i, k] - * PACKED_KER[ + padded_data[h, (i * stride_h) + kh_i, (j * stride_w) + kw_i, k].astype("int32") + * packed_kernel[ k // 4, (2 * ((3 * kh_i + kw_i) // 2)) + ((k % 4) // 2), (2 * ((kh_i + kw_i) % 2)) + (k % 2), - ], + ].astype("int32"), axis=(kh_i, kw_i), ), name="depthwise_conv2d", - tag=f"depthwise_conv2d_nhwc_{P_H}_{P_W}_dsp", + tag=f"depthwise_conv2d_nhwc_{padded_h}_{padded_w}_dsp", ) @@ -212,12 +226,12 @@ def _callback(op): # extract tensors output = op.output(0) - padded_data = conv_out.op.input_tensors[0] - packed_kernel = conv_out.op.input_tensors[1] + padded_data = output.op.input_tensors[0] + packed_kernel = output.op.input_tensors[1] kernel = packed_kernel.op.input_tensors[0] - B, P_H, P_W, C = padded_data.shape - K_H, K_W, _, _ = kernel.shape + _, _, padded_w, channels = padded_data.shape + kernel_h, kernel_w, _, _ = kernel.shape suffix = "".join(random.choices(string.ascii_uppercase, k=8)) b_ax, y_ax, x_ax, c_ax = schedule[output].op.axis @@ -225,154 +239,15 @@ def _callback(op): c_ax_o, c_ax_i = schedule[output].split(c_ax, factor=4) schedule[output].reorder(b_ax, c_ax_o, y_ax, x_ax, ky_ax, kx_ax, c_ax_i) - quad_channel_convolve = intrin_quad_channel_convolve_3x3(P_H, P_W, C, K_H, K_W, suffix) - s[CONVOLVED].tensorize(ky_ax, gemv) - sched[output].pragma( - b_ax, "import_c", quad_channel_convolve_3x3_impl(P_H, P_W, C, K_H, K_W, suffix) + quad_channel_convolve = intrin_quad_channel_convolve( + padded_w, channels, kernel_h, kernel_w, suffix ) - - traverse_inline(sched, outs[-1].op, _callback) - return sched - - -def intrin_quad_channel_convolve_3x3(P_H, P_W, C, K_H, K_W, suffix): - a = te.placeholder((K_H, K_W, 4), name="a", dtype="int8") - b = te.placeholder((K_H * K_W + 1, 4), name="b", dtype="int8") - kh_i = te.reduce_axis((0, K_H), name="kh_i") - kw_i = te.reduce_axis((0, K_W), name="kw_i") - - c = te.compute( - (4,), - lambda k: te.sum( - a[kh_i, kw_i, k] - * b[ - (2 * ((3 * kh_i + kw_i) // 2)) + ((k % 4) // 2), - (2 * ((kh_i + kw_i) % 2)) + (k % 2), - ], - axis=(kh_i, kw_i), - ), - name="c", - ) - - Ab = tvm.tir.decl_buffer(a.shape, a.dtype, name="A", offset_factor=1, strides=[P_W * C, C, 1]) - Bb = tvm.tir.decl_buffer(b.shape, b.dtype, name="B", offset_factor=1, strides=[4, 1]) - Cb = tvm.tir.decl_buffer(c.shape, c.dtype, name="C", offset_factor=1, strides=[1]) - - def intrin_func(ins, outs): - ib = tvm.tir.ir_builder.create() - aa, bb = ins - cc = outs[0] - ib.emit( - tvm.tir.call_extern( - "int32", - f"kernel_convolve_noedge_{P_H}_{P_W}_{C}_{K_H}_{K_W}_{suffix}", - cc.access_ptr("w"), - aa.access_ptr("r"), - bb.access_ptr("r"), - ) + schedule[output].tensorize(ky_ax, quad_channel_convolve) + schedule[output].pragma( + b_ax, + "import_c", + quad_channel_convolve_impl(padded_w, channels, kernel_h, kernel_w, suffix), ) - return ib.get() - - return te.decl_tensor_intrin(c.op, intrin_func, binds={a: Ab, b: Bb, c: Cb}) - - -def quad_channel_convolve_3x3_impl(P_H, P_W, C, K_H, K_W, suffix): - return ( - textwrap.dedent( - f""" - - #include - - // __SXTB16(_ROR(X, Y)) is combined into one assembly instruction - - #define QUAD_CHANNEL_TENSOR_REARRANGE_SUM_DSP( \ - arranged_kernel, \ - tensor_v0_c3210, tensor_v1_c3210, \ - sum0, sum1, sum2, sum3) {{ \ - \ - uint32_t tensor_v0_c20 = __SXTB16(tensor_v0_c3210); \ - uint32_t tensor_v0_c31 = __SXTB16(__ROR(tensor_v0_c3210, 8)); \ - uint32_t tensor_v1_c20 = __SXTB16(tensor_v1_c3210); \ - uint32_t tensor_v1_c31 = __SXTB16(__ROR(tensor_v1_c3210, 8)); \ - \ - uint32_t kernel_v1c1_v1c0_v0c1_v0c0 = *arranged_kernel++; \ - uint32_t kernel_v1c3_v1c2_v0c3_v0c2 = *arranged_kernel++; \ - \ - uint32_t kernel_v10_c0 = __SXTB16(kernel_v1c1_v1c0_v0c1_v0c0); \ - uint32_t kernel_v10_c1 = __SXTB16(__ROR(kernel_v1c1_v1c0_v0c1_v0c0, 8)); \ - uint32_t kernel_v10_c2 = __SXTB16(kernel_v1c3_v1c2_v0c3_v0c2); \ - uint32_t kernel_v10_c3 = __SXTB16(__ROR(kernel_v1c3_v1c2_v0c3_v0c2, 8)); \ - \ - uint32_t tensor_v10_c0 = __PKHBT(tensor_v0_c20, tensor_v1_c20, 16); \ - uint32_t tensor_v10_c1 = __PKHBT(tensor_v0_c31, tensor_v1_c31, 16); \ - uint32_t tensor_v10_c2 = __PKHTB(tensor_v1_c20, tensor_v0_c20, 16); \ - uint32_t tensor_v10_c3 = __PKHTB(tensor_v1_c31, tensor_v0_c31, 16); \ - \ - sum_c0 = __SMLAD(tensor_v10_c0, kernel_v10_c0, sum_c0); \ - sum_c1 = __SMLAD(tensor_v10_c1, kernel_v10_c1, sum_c1); \ - sum_c2 = __SMLAD(tensor_v10_c2, kernel_v10_c2, sum_c2); \ - sum_c3 = __SMLAD(tensor_v10_c3, kernel_v10_c3, sum_c3); \ - }} - - - /* Here, we want to take the LOWER BYTES of 32 bit numbers v3 - v0 - * and concatenate them as "v3 || v2 || v1 || v0". In C++, this is: - - return ((sum_c0 & 0x000000FF)) + - ((sum_c1 & 0x000000FF) << 8) + - ((sum_c2 & 0x000000FF) << 16) + - ((sum_c3 & 0x000000FF) << 24); - - * Naively, this takes 4x ands, 3x adds, 3x shifts. With optimization flags, - * clang optimizes this down to eight operations: - - mov r12, #255 - and r0, r0, #255 - orr r12, r12, #65280 - and r1, r12, r1, lsl #8 - orr r0, r1, r0 - and r1, r2, #255 - orr r0, r0, r1, lsl #16 - orr r0, r0, r3, lsl #24 - - * But being clever engineers, we can do it in four instructions. I think, - * but have been unable to prove, that fewer is impossible. */ - - #define WRITE_QUAD_BYTE_JOIN_DSP(out, v3, v2, v1, v0) {{ \ - uint32_t v3_00_v1_00 = PKHBT(v1 << 8, v3, 24); \ - uint32_t gg_v2_gg_v0 = PKHBT(v0, v2, 16); \ - out[0] = UXTAB16(v3_00_v1_00, gg_v2_gg_v0); \ - }} - - /* We do four channels at once to get this speed boost. */ - extern "C" int kernel_convolve_noedge_{P_H}_{P_W}_{C}_{K_H}_{K_W}_{suffix}( - uint32_t *out, - uint32_t *tensor, - uint32_t *packed_kernel) {{ - - uint32_t sum_c0 = 0; - uint32_t sum_c1 = 0; - uint32_t sum_c2 = 0; - uint32_t sum_c3 = 0; - - QUAD_CHANNEL_TENSOR_REARRANGE_SUM( - packed_kernel, *tensor, *(tensor + {C // 4}), - sum_c0, sum_c1, sum_c2, sum_c3) - QUAD_CHANNEL_TENSOR_REARRANGE_SUM( - packed_kernel, *(tensor + {(2) * C // 4}), *(tensor + {P_W * (C // 4)}), - sum_c0, sum_c1, sum_c2, sum_c3) - QUAD_CHANNEL_TENSOR_REARRANGE_SUM( - packed_kernel, *(tensor + {(P_W + 1) * (C // 4)}), *(tensor + {(P_W + 2) * (C // 4)}), - sum_c0, sum_c1, sum_c2, sum_c3) - QUAD_CHANNEL_TENSOR_REARRANGE_SUM( - packed_kernel, *(tensor + {(2 * P_W) * (C // 4)}), *(tensor + {(2 * P_W + 1) * (C // 4)}), - sum_c0, sum_c1, sum_c2, sum_c3) - QUAD_CHANNEL_TENSOR_REARRANGE_SUM( - packed_kernel, *(tensor + {(2 * P_W + 2) * (C // 4)}), 0, - sum_c0, sum_c1, sum_c2, sum_c3) - - WRITE_QUAD_BYTE_JOIN_DSP(out, sum_c3, sum_c2, sum_c1, sum_c0); - }} - """ - ), - ) + + traverse_inline(schedule, outs[-1].op, _callback) + return schedule diff --git a/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/quad_channel_convolve.py b/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/quad_channel_convolve.py new file mode 100644 index 000000000000..70b166aed5b2 --- /dev/null +++ b/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/quad_channel_convolve.py @@ -0,0 +1,178 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""This is a special intrinsic used for depthwise convolution using Cortex-M DSP instructions +(v7e-m). It takes as inputs an int8 HWC data tensor and an int8 CHWc kernel. This intrinsic "lays" +the kernel on top of the data tensors tarting from a given pointer, performs signed sixteen-bit +multiplies on each pair of values, and sums all the products in an int32 accumlator. This process is +repeated four times giving four int32 outputs - one per channel.""" + +import textwrap + +from tvm import te, tir + +def intrin_quad_channel_convolve(tensor_w, channels, kernel_h, kernel_w, suffix): + data_slice = te.placeholder((kernel_h, kernel_w, 4), name="a", dtype="int8") + + if kernel_h * kernel_w % 2 == 1: + kernel_length = kernel_h * kernel_w + 1 + else: + kernel_length = kernel_h * kernel_w + kernel_slice = te.placeholder((kernel_length, 4), name="b", dtype="int8") + + kh_i = te.reduce_axis((0, kernel_h), name="kh_i") + kw_i = te.reduce_axis((0, kernel_w), name="kw_i") + + output_slice = te.compute( + (4,), + lambda k: te.sum( + data_slice[kh_i, kw_i, k].astype("int32") + * kernel_slice[ + (2 * ((3 * kh_i + kw_i) // 2)) + ((k % 4) // 2), + (2 * ((kh_i + kw_i) % 2)) + (k % 2), + ].astype("int32"), + axis=(kh_i, kw_i), + ), + name="c", + ) + + data_buf = tir.decl_buffer( + data_slice.shape, + data_slice.dtype, + name="data", + offset_factor=1, + strides=[tensor_w * channels, channels, 1], + ) + kernel_buf = tir.decl_buffer( + kernel_slice.shape, kernel_slice.dtype, name="kernel", offset_factor=1, strides=[4, 1] + ) + output_buf = tir.decl_buffer( + output_slice.shape, output_slice.dtype, name="output", offset_factor=1, strides=[1] + ) + + def intrin_func(ins, outs): + builder = tir.ir_builder.create() + builder.emit( + tir.call_extern( + "int32", + f"kernel_convolve_{tensor_w}_{channels}_{kernel_h}_{kernel_w}_{suffix}", + outs[0].access_ptr("w"), + ins[0].access_ptr("r"), + ins[1].access_ptr("r"), + ) + ) + return builder.get() + + return te.decl_tensor_intrin( + output_slice.op, + intrin_func, + binds={data_slice: data_buf, kernel_slice: kernel_buf, output_slice: output_buf}, + ) + + +def quad_channel_convolve_impl(tensor_w, channels, kernel_h, kernel_w, suffix): + # intrin_quad_channel_convolve supports any kernel size, but this function only supports + # 3x3 kernels (though this could be fixed with work). + assert kernel_h == kernel_w == 3 + + return textwrap.dedent( + ( + f""" + #include + #include + + // __SXTB16(_ROR(X, Y)) is combined into one assembly instruction + + #define TVMGEN_QUAD_CHANNEL_REARRANGE_SUM_DSP( \ + arranged_kernel, \ + tensor_v0_c3210, tensor_v1_c3210, \ + sum0, sum1, sum2, sum3) {{ \ + \ + uint32_t tensor_v0_c20 = __SXTB16(tensor_v0_c3210); \ + uint32_t tensor_v0_c31 = __SXTB16(__ROR(tensor_v0_c3210, 8)); \ + uint32_t tensor_v1_c20 = __SXTB16(tensor_v1_c3210); \ + uint32_t tensor_v1_c31 = __SXTB16(__ROR(tensor_v1_c3210, 8)); \ + \ + uint32_t kernel_v1c1_v1c0_v0c1_v0c0 = *arranged_kernel++; \ + uint32_t kernel_v1c3_v1c2_v0c3_v0c2 = *arranged_kernel++; \ + \ + uint32_t kernel_v10_c0 = __SXTB16(kernel_v1c1_v1c0_v0c1_v0c0); \ + uint32_t kernel_v10_c1 = __SXTB16(__ROR(kernel_v1c1_v1c0_v0c1_v0c0, 8)); \ + uint32_t kernel_v10_c2 = __SXTB16(kernel_v1c3_v1c2_v0c3_v0c2); \ + uint32_t kernel_v10_c3 = __SXTB16(__ROR(kernel_v1c3_v1c2_v0c3_v0c2, 8)); \ + \ + uint32_t tensor_v10_c0 = __PKHBT(tensor_v0_c20, tensor_v1_c20, 16); \ + uint32_t tensor_v10_c1 = __PKHBT(tensor_v0_c31, tensor_v1_c31, 16); \ + uint32_t tensor_v10_c2 = __PKHTB(tensor_v1_c20, tensor_v0_c20, 16); \ + uint32_t tensor_v10_c3 = __PKHTB(tensor_v1_c31, tensor_v0_c31, 16); \ + \ + sum_c0 = __SMLAD(tensor_v10_c0, kernel_v10_c0, sum_c0); \ + sum_c1 = __SMLAD(tensor_v10_c1, kernel_v10_c1, sum_c1); \ + sum_c2 = __SMLAD(tensor_v10_c2, kernel_v10_c2, sum_c2); \ + sum_c3 = __SMLAD(tensor_v10_c3, kernel_v10_c3, sum_c3); \ + }} + + /* We do four channels at once to get this speed boost. */ + #ifdef __cplusplus + extern "C" + #endif + int32_t kernel_convolve_{tensor_w}_{channels}_{kernel_h}_{kernel_w}_{suffix}( + uint32_t *out, + uint32_t *tensor, + uint32_t *packed_kernel) {{ + + uint32_t sum_c0 = 0; + uint32_t sum_c1 = 0; + uint32_t sum_c2 = 0; + uint32_t sum_c3 = 0; + + TVMGEN_QUAD_CHANNEL_REARRANGE_SUM_DSP( + packed_kernel, + *tensor, + *(tensor + {channels // 4}), + sum_c0, sum_c1, sum_c2, sum_c3) + TVMGEN_QUAD_CHANNEL_REARRANGE_SUM_DSP( + packed_kernel, + *(tensor + {(2) * channels // 4}), + *(tensor + {tensor_w * (channels // 4)}), + sum_c0, sum_c1, sum_c2, sum_c3) + TVMGEN_QUAD_CHANNEL_REARRANGE_SUM_DSP( + packed_kernel, + *(tensor + {(tensor_w + 1) * (channels // 4)}), + *(tensor + {(tensor_w + 2) * (channels // 4)}), + sum_c0, sum_c1, sum_c2, sum_c3) + TVMGEN_QUAD_CHANNEL_REARRANGE_SUM_DSP( + packed_kernel, + *(tensor + {(2 * tensor_w) * (channels // 4)}), + *(tensor + {(2 * tensor_w + 1) * (channels // 4)}), + sum_c0, sum_c1, sum_c2, sum_c3) + TVMGEN_QUAD_CHANNEL_REARRANGE_SUM_DSP( + packed_kernel, + *(tensor + {(2 * tensor_w + 2) * (channels // 4)}), + 0, + sum_c0, sum_c1, sum_c2, sum_c3) + + out[0] = sum_c0; + out[1] = sum_c1; + out[2] = sum_c2; + out[3] = sum_c3; + return 0; + }} + + #undef TVMGEN_QUAD_CHANNEL_REARRANGE_SUM_DSP + """ + ) + ) diff --git a/tests/python/relay/strategy/arm_cpu/test_depthwise_conv2d.py b/tests/python/relay/strategy/arm_cpu/test_depthwise_conv2d.py index ee0d51c321f7..c3f3660e53fc 100644 --- a/tests/python/relay/strategy/arm_cpu/test_depthwise_conv2d.py +++ b/tests/python/relay/strategy/arm_cpu/test_depthwise_conv2d.py @@ -147,5 +147,29 @@ class TestDepthwiseConv2d_NHWC_HWOI(BasicDepthwiseConv2dTests): schedule_name = tvm.testing.parameter("depthwise_conv2d_nhwc.generic") +class TestDepthwiseConv2d_NHWC_HWOI_DSP(BasicDepthwiseConv2dTests): + """This test is for depthwise_conv2d_nhwc.generic schedule.""" + + data_shape, kernel_size, num_filter, strides, padding, dilation = tvm.testing.parameters( + # Depthwise_conv2d parameters from MobileNetV1 0.25x + ((1, 48, 48, 8), (3, 3), 8, (1, 1), 1, 1), + ((1, 48, 48, 16), (3, 3), 16, (2, 2), 1, 1), + ((1, 24, 24, 32), (3, 3), 32, (1, 1), 1, 1), + ((1, 24, 24, 32), (3, 3), 32, (2, 2), 1, 1), + ((1, 12, 12, 64), (3, 3), 64, (1, 1), 1, 1), + ((1, 12, 12, 64), (3, 3), 64, (2, 2), 1, 1), + ((1, 6, 6, 128), (3, 3), 128, (1, 1), 1, 1), + ((1, 6, 6, 128), (3, 3), 128, (2, 2), 1, 1), + ((1, 3, 3, 256), (3, 3), 256, (1, 1), 1, 1), + + # Asymmetric height and width + ((1, 25, 5, 64), (3, 3), 64, (1, 1), 1, 1), + ) + data_layout = tvm.testing.parameter("NHWC") + dtype = tvm.testing.parameter("int8") + kernel_layout = tvm.testing.parameter("HWOI") + schedule_name = tvm.testing.parameter("depthwise_conv2d_nhwc_dsp.arm_cpu") + + if __name__ == "__main__": tvm.testing.main()