[Bugfix][Relay][Strategy] Enable compile time transformation of weights matrix for arm_cpu NHWC quantized conv2d #15584

Merged · 5 commits · Aug 23, 2023
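In short, this change makes the GEMM weight-matrix transformation for quantized NHWC conv2d on arm_cpu happen at compile time, and pads the reduction dimension K so the transformed weights always cover a whole multiple of four column tiles, which is what the tensorized schedules expect. A minimal NumPy sketch of the idea (illustrative only; the tile sizes and the function name are example values, not TVM's exact implementation):

# Illustrative sketch (not TVM's actual code): pad a (K, N) GEMM weight matrix and
# re-lay it out into (N_padded // tile_rows, K_padded // tile_cols, tile_rows, tile_cols)
# blocks, with K padded up to a multiple of four column tiles.
import numpy as np

def interleave_transpose(weights, tile_rows=4, tile_cols=16):
    K, N = weights.shape
    pad_K = (-K) % (tile_cols * 4)  # pad K to a multiple of 4 column tiles
    pad_N = (-N) % tile_rows
    w = np.pad(weights, ((0, pad_K), (0, pad_N))).T  # (N_padded, K_padded)
    n_p, k_p = w.shape
    blocks = w.reshape(n_p // tile_rows, tile_rows, k_p // tile_cols, tile_cols)
    return blocks.transpose(0, 2, 1, 3)

print(interleave_transpose(np.ones((42, 80), dtype="int8")).shape)  # (20, 4, 4, 16)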
33 changes: 19 additions & 14 deletions python/tvm/relay/op/strategy/arm_cpu.py
@@ -468,24 +468,29 @@ def conv2d_gemm_without_weight_transform_strategy_arm_cpu(attrs, inputs, out_typ
layout = attrs.data_layout
data = inputs[0]
strategy = _op.OpStrategy()
is_aarch64 = target.features.is_aarch64
has_asimd = target.features.has_asimd
has_dot_prod = target.features.has_dotprod

interleaved_compute = topi.arm_cpu.compute_conv2d_NHWC_quantized_interleaved_without_transform
native_compute = topi.arm_cpu.compute_conv2d_NHWC_quantized_native_without_transform
if layout == "NHWC" and data.dtype in ["int8", "uint8"]:
strategy.add_implementation(
wrap_compute_conv2d_gemm(native_compute),
wrap_topi_schedule(
topi.arm_cpu.schedule_conv2d_NHWC_quantized_native_without_transform
),
name="conv2d_NHWC_quantized_native_without_transform.arm_cpu",
)
strategy.add_implementation(
wrap_compute_conv2d_gemm(interleaved_compute),
wrap_topi_schedule(
topi.arm_cpu.schedule_conv2d_NHWC_quantized_interleaved_without_transform
),
name="conv2d_NHWC_quantized_interleaved_without_transform.arm_cpu",
)
if has_dot_prod:
strategy.add_implementation(
wrap_compute_conv2d_gemm(native_compute),
wrap_topi_schedule(
topi.arm_cpu.schedule_conv2d_NHWC_quantized_native_without_transform
),
name="conv2d_NHWC_quantized_native_without_transform.arm_cpu",
)
if is_aarch64 and has_asimd:
strategy.add_implementation(
wrap_compute_conv2d_gemm(interleaved_compute),
wrap_topi_schedule(
topi.arm_cpu.schedule_conv2d_NHWC_quantized_interleaved_without_transform
),
name="conv2d_NHWC_quantized_interleaved_without_transform.arm_cpu",
)
else:
raise RuntimeError(
f"Unsupported conv2d_NHWC_quantized_without_transform layout {layout}"
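For context, the feature flags queried above come from the compilation target. A small sketch (assuming a TVM build where Target exposes the .features attribute used in this diff) showing how the targets exercised in the tests map onto those flags:

import tvm

for target_str in [
    "llvm -device=arm_cpu -mtriple=aarch64-linux-gnu -mattr=+neon",
    "llvm -device=arm_cpu -mtriple=aarch64-linux-gnu -mattr=+v8.2a,+dotprod",
]:
    # has_dotprod gates the native schedule; is_aarch64 + has_asimd gate the interleaved one.
    target = tvm.target.Target(target_str)
    print(
        target_str,
        target.features.is_aarch64,
        target.features.has_asimd,
        target.features.has_dotprod,
    )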
23 changes: 9 additions & 14 deletions python/tvm/topi/arm_cpu/conv2d_alter_op.py
@@ -77,8 +77,15 @@ def interleave_transpose_weights(inputs, data, kernel, interleave_A):

if N % tile_rows_B != 0:
pad_N = tile_rows_B - (N % tile_rows_B)
if K % tile_cols_B != 0:
pad_K = tile_cols_B - (K % tile_cols_B)

# Tensorize will later make use of 4 tiles at once across the columns, so make sure we pad
# such that the number of columns is a multiple of 4
column_multiplier = 4
tile_cols_multiplied = tile_cols_B * column_multiplier
K_misalignment = K % tile_cols_multiplied

if K_misalignment != 0:
pad_K = tile_cols_multiplied - K_misalignment

N_padded = N + pad_N
K_padded = K + pad_K
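As a worked example of the padding arithmetic above (the tile widths here are hypothetical, chosen only to show the rounding), K is now rounded up to a multiple of tile_cols_B * 4 rather than tile_cols_B alone; the same rule is mirrored in the C++ type relation further down (indexmod(K, k * 4)):

def pad_K_for(K, tile_cols_B, column_multiplier=4):
    tile_cols_multiplied = tile_cols_B * column_multiplier
    misalignment = K % tile_cols_multiplied
    return 0 if misalignment == 0 else tile_cols_multiplied - misalignment

assert pad_K_for(42, tile_cols_B=1) == 2     # 42 -> 44, a multiple of 4
assert pad_K_for(666, tile_cols_B=16) == 38  # 666 -> 704, a multiple of 64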
@@ -434,12 +441,6 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):
return relay.nn.contrib_conv2d_nchwc(*inputs, **new_attrs)

if topi_tmpl == "conv2d_NHWC_quantized_interleaved.arm_cpu":
# TODO(masahi): This schedule can easily result in a tensorization error
# if used in the fallback mode
if cfg.is_fallback: # if is fallback, clear query cache and return None
autotvm.task.clear_fallback_cache(target, workload)
return None

assert data_layout == "NHWC" and kernel_layout == "HWIO"
KH, KW, _, OC = get_const_tuple(kernel.shape)
new_workload_name = "conv2d_NHWC_quantized_interleaved_without_transform.arm_cpu"
@@ -456,12 +457,6 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):
inputs[0], new_kernel_expr, **new_attrs
)
if topi_tmpl == "conv2d_NHWC_quantized_native.arm_cpu":
# TODO(masahi): This schedule can easily result in a tensorization error
# if used in the fallback mode
if cfg.is_fallback: # if is fallback, clear query cache and return None
autotvm.task.clear_fallback_cache(target, workload)
return None

assert data_layout == "NHWC" and kernel_layout == "HWIO"
KH, KW, _, OC = get_const_tuple(kernel.shape)
new_workload_name = "conv2d_NHWC_quantized_native_without_transform.arm_cpu"
4 changes: 2 additions & 2 deletions src/relay/op/nn/convolution.cc
@@ -1510,10 +1510,10 @@ bool Conv2DGemmWeightTransformRel(const Array<Type>& types, int num_inputs, cons
const auto K = weight->shape[0] * weight->shape[1] * weight->shape[2];
const auto N = weight->shape[3];

auto K_mod_k = indexmod(K, k);
auto K_mod_k = indexmod(K, k * 4);
auto N_mod_n = indexmod(N, n);

auto pad_K = tvm::if_then_else(K_mod_k != 0, k - K_mod_k, tir::make_zero(DataType::Int(32)));
auto pad_K = tvm::if_then_else(K_mod_k != 0, k * 4 - K_mod_k, tir::make_zero(DataType::Int(32)));
auto pad_N = tvm::if_then_else(N_mod_n != 0, n - N_mod_n, tir::make_zero(DataType::Int(32)));

const auto N_padded = N + pad_N;
20 changes: 17 additions & 3 deletions tests/python/relay/strategy/test_select_implementation.py
@@ -24,7 +24,7 @@
import tvm
from tvm import relay
from tvm import te
from tvm.relay.testing import run_infer_type
from tvm.relay.testing import run_infer_type, run_opt_pass
import tvm.testing
from tvm import topi

@@ -63,12 +63,24 @@ def test_concatenate(target, expected_implementation):
("llvm -device=arm_cpu", "conv2d_nhwc_spatial_pack.arm_cpu"),
(
"llvm -device=arm_cpu -mtriple=aarch64-linux-gnu -mattr=+neon",
"conv2d_NHWC_quantized_interleaved.arm_cpu",
"conv2d_NHWC_quantized_interleaved_without_transform.arm_cpu",
),
(
"llvm -device=arm_cpu -mtriple=armv8l-linux-gnu -mattr=+neon",
"conv2d_nhwc_spatial_pack.arm_cpu",
),
(
"llvm -device=arm_cpu -mtriple=aarch64-linux-gnu",
"conv2d_NHWC_quantized_interleaved_without_transform.arm_cpu",
),
(
"llvm --device=arm_cpu --mtriple=aarch64-linux-gnu -mattr=+v8.2a,+dotprod",
"conv2d_NHWC_quantized_native_without_transform.arm_cpu",
),
(
"llvm --device=arm_cpu --mtriple=aarch64-linux-gnu -mattr=+v8.2a,+i8mm",
"conv2d_NHWC_quantized_interleaved_without_transform.arm_cpu",
),
],
)
def test_int8_conv2d(target, expected_impl):
@@ -89,16 +101,18 @@ def test_int8_conv2d(target, expected_impl):
channels=channels,
data_layout=data_layout,
kernel_layout=kernel_layout,
out_dtype=dtype,
)
out = run_infer_type(out)

with target:
out = run_opt_pass(out, relay.transform.AlterOpLayout())
impl, _ = relay.backend.te_compiler.select_implementation(
out.op,
out.attrs,
[te.placeholder(data_shape, dtype), te.placeholder(weight_shape, dtype)],
out.checked_type,
target,
use_autotvm=False,
)

assert impl.name == expected_impl
2 changes: 1 addition & 1 deletion tests/python/relay/test_pass_alter_op_layout.py
@@ -1230,7 +1230,7 @@ def test_alter_layout_nhwc_int8_aarch64():
"""Check that AlterOplayout does not alter NHWC data layout."""
from tvm import autotvm

expected_workload_shape = (20, 42, 4, 16)
expected_workload_shape = (20, 44, 4, 16)

# We use Int8Fallback to disable the fallback flag
# and to test the new workload produced during the pass