Skip to content

Commit

Permalink
[TOPI] Reduce code redundancy in conv2d weights transformation
Browse files Browse the repository at this point in the history
Refactored out a piece of functionality that was duplicated in the `conv2d_gemm_weight_transform` and `interleave_transpose_weights` functions. This duplication has previously led to bugs when a change was made to one function but not the other, as in apache#15584.
Determining the necessary padding for the interleaved and transposed weights matrix has now been separated into a new utility function, allowing future changes to be reflected in both callers.
  • Loading branch information
Anndrey24 committed Nov 6, 2023
1 parent ffa0033 commit 8b08d01
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 30 deletions.
38 changes: 38 additions & 0 deletions python/tvm/topi/arm_cpu/arm_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,3 +73,41 @@ def get_tiling_B_interleaved_t(interleave_A):
tile_cols_B = 16

return tile_rows_B, tile_cols_B


def get_conv2d_weights_padding(N, K, tile_rows, tile_cols):
    """Compute the necessary padding for matrix B', where B'
    is the transposed and interleaved version of matrix B in C=A*B.

    Parameters
    ----------
    N : int
        Number of rows in B' = OC
    K : int
        Number of columns in B' = KW * KH * IC
    tile_rows : int
        tile rows of B'
    tile_cols : int
        tile columns of B'

    Returns
    -------
    pad_N : int
        Padding for the N axis, so that N + pad_N is a multiple of tile_rows.
    pad_K : int
        Padding for the K axis, so that K + pad_K is a multiple of
        4 * tile_cols.
    """
    # Rows only need to be rounded up to a whole number of tiles.
    row_remainder = N % tile_rows
    pad_N = tile_rows - row_remainder if row_remainder != 0 else 0

    # Tensorize will later make use of 4 tiles at once across the columns, so
    # the columns are padded to a multiple of 4 tiles rather than a single tile.
    column_multiplier = 4
    column_alignment = tile_cols * column_multiplier
    column_remainder = K % column_alignment
    pad_K = column_alignment - column_remainder if column_remainder != 0 else 0

    return pad_N, pad_K
18 changes: 2 additions & 16 deletions python/tvm/topi/arm_cpu/conv2d_alter_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from ..x86.conv2d import _get_default_config as _get_x86_default_config
from ..x86.conv2d_int8 import _get_default_config_int8
from .conv2d_int8 import is_int8_hw_support
from .arm_utils import get_tiling_B_interleaved_t
from .arm_utils import get_tiling_B_interleaved_t, get_conv2d_weights_padding
from ..generic.conv2d import conv2d_alter_int8_common
from .mprofile.dsp.micro_kernel.common import num_simd_lanes_per_word

Expand Down Expand Up @@ -72,21 +72,7 @@ def interleave_transpose_weights(inputs, data, kernel, interleave_A):

# Get tiling information for the interleaved transposed version of B
tile_rows_B, tile_cols_B = get_tiling_B_interleaved_t(interleave_A)

pad_K = 0
pad_N = 0

if N % tile_rows_B != 0:
pad_N = tile_rows_B - (N % tile_rows_B)

# Tensorize will later make use of 4 tiles at once across the columns so make sure we pad such
# that the columns is multiple of 4
column_multiplier = 4
tile_cols_multiplied = tile_cols_B * column_multiplier
K_misalignment = K % tile_cols_multiplied

if K_misalignment != 0:
pad_K = tile_cols_multiplied - K_misalignment
pad_N, pad_K = get_conv2d_weights_padding(N, K, tile_rows_B, tile_cols_B)

N_padded = N + pad_N
K_padded = K + pad_K
Expand Down
15 changes: 1 addition & 14 deletions python/tvm/topi/nn/conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -617,20 +617,7 @@ def conv2d_gemm_weight_transform(kernel, tile_rows, tile_cols):
(K, N), lambda x, y: kernel[(x // IC) // KW, (x // IC) % KW, x % IC, y], "weight_flatten"
)

pad_K = 0
pad_N = 0

if N % tile_rows != 0:
pad_N = tile_rows - (N % tile_rows)

# Tensorize will later make use of 4 tiles at once across the columns so make sure we pad such
# that the columns is multiple of 4
column_multiplier = 4
tile_cols_multiplied = tile_cols * column_multiplier
K_misalignment = K % tile_cols_multiplied

if K_misalignment != 0:
pad_K = tile_cols_multiplied - K_misalignment
pad_N, pad_K = tvm.topi.arm_cpu.arm_utils.get_conv2d_weights_padding(N, K, tile_rows, tile_cols)

N_padded = N + pad_N
K_padded = K + pad_K
Expand Down

0 comments on commit 8b08d01

Please sign in to comment.