Revert "[SME][TOPI] Add conv2d NHWC SME fp32 schedule (#17003)" #17038

Closed
Revert "[SME][TOPI] Add conv2d NHWC SME fp32 schedule (#17003)"
This reverts commit cab54e0.
tqchen authored May 28, 2024

Verified: this commit was signed with the committer's verified signature (aitbw, Angel Perez). The key has expired.
commit b71a9a3827d81ac17da5f5bc608583f1a02bd0d8
15 changes: 0 additions & 15 deletions python/tvm/relay/op/strategy/arm_cpu.py
@@ -253,18 +253,6 @@ def conv2d_strategy_arm_cpu(attrs, inputs, out_type, target):
)
# Non-quantized cases
if is_aarch64 and data.dtype in ["float32", "float16"]:
if (
target.features.has_sme
and data.dtype in ["float32"]
and kernel.dtype in ["float32"]
and out_type.dtype in ["float32"]
):
strategy.add_implementation(
wrap_compute_conv2d(topi.arm_cpu.compute_conv2d_NHWC_hybrid_SME),
lambda: None,
name="conv2d_NHWC_hybrid_SME.arm_cpu",
plevel=12,
)
if target.features.has_sve:
# This strategy is currently suboptimal because of LLVM's limited support
# for scalable vector alias analysis, which causes redundant loads / stores
@@ -818,9 +806,6 @@ def arm_cpu_tir_strategy(sch: tir.Schedule) -> bool:
if matmul_block and sch.get(matmul_block).annotations.get("schedule_type", "") == "sme":
topi.arm_cpu.matmul.tir_schedule_matmul_sme(sch)
return True
elif has_block(sch, "conv2d_gemm_output"):
topi.arm_cpu.schedule_conv2d_NHWC_hybrid_TIR(sch)
return True

# Fallback to TE schedule for operators we have not written a special TIR schedule for
return False
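
For context, the guard that the removed registration depended on can be probed directly from Python. A minimal sketch, assuming an LLVM build that understands the AArch64 SME attribute; the exact target string below is illustrative and not taken from this PR:

import tvm

# Hypothetical AArch64 target string; "+sme" is what drives
# target.features.has_sme, the flag tested by the strategy above.
target = tvm.target.Target("llvm -mtriple=aarch64-linux-gnu -mattr=+v9.2a,+sme")
print(target.features.has_sme)  # True when the target advertises SME
print(target.features.has_sve)  # separate flag consulted by the SVE strategy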
7 changes: 0 additions & 7 deletions python/tvm/testing/utils.py
@@ -1071,13 +1071,6 @@ def _has_cpu_feat(features):
)


requires_aarch64_sme = Feature(
"arm_sme",
"AArch64 SME",
run_time_check=lambda: _has_cpu_feat("sme"),
)


requires_x86_vnni = Feature(
"x86_vnni",
"x86 VNNI Extensions",
18 changes: 4 additions & 14 deletions python/tvm/topi/arm_cpu/arm_utils.py
@@ -22,7 +22,7 @@
from tvm.tir.expr import PrimExpr


def get_tiling_A(interleave_A, in_dtype, use_sme=False):
def get_tiling_A(interleave_A, in_dtype):
"""Compute the tiling information for matrix A in C=A*B,
which corresponds to the im2col-transformed input matrix.

@@ -42,8 +42,6 @@ def get_tiling_A(interleave_A, in_dtype, use_sme=False):
determines if A is expected to be interleaved
in_dtype : str
input datatype
use_sme : bool
determines if SME operations on scalable vectors are expected

Returns
----------
@@ -67,11 +65,8 @@ def get_tiling_A(interleave_A, in_dtype, use_sme=False):
# tile size should be 4x16
tile_M = 4
tile_K = 16
elif use_sme:
tile_M = 2 * 4 * tvm.tir.vscale()
tile_K = 2 * 4 * tvm.tir.vscale()
else:
# In non-SME, non-quantized cases, A is not interleaved.
# In non-quantized cases, A is not interleaved.
# We are loading 4 rows from A.
# Each row will contain 4 elements, along the dimension of reduction
tile_M = 4
@@ -80,7 +75,7 @@ def get_tiling_A(interleave_A, in_dtype, use_sme=False):
return tile_M, tile_K


def get_tiling_B_transformed(interleave_A, in_dtype, use_scalable_vectors=False, use_sme=False):
def get_tiling_B_transformed(interleave_A, in_dtype, use_scalable_vectors=False):
"""Compute the tiling information for matrix B', where B'
is the tiled, interleaved (and transposed) version of matrix B in C=A*B.

@@ -102,8 +97,6 @@ def get_tiling_B_transformed(interleave_A, in_dtype, use_scalable_vectors=False,
input datatype
use_scalable_vectors : bool
determines if operations on scalable vectors are expected
use_sme : bool
determines if SME operations on scalable vectors are expected


Returns
@@ -138,10 +131,7 @@ def get_tiling_B_transformed(interleave_A, in_dtype, use_scalable_vectors=False,
# we load 4 rows of B' (i.e., 4 columns of B). Each of them will contain 16 elements
tile_N = 4
tile_K = 16
elif use_sme:
tile_N = 2 * 4 * tvm.tir.vscale()
tile_K = 2 * 4 * tvm.tir.vscale()
# In non-SME, non-quantized cases, A is not interleaved.
# In non-quantized cases, A is not interleaved.
elif use_scalable_vectors:
if in_dtype == "float16":
# Each load from B' contains 32 * vscale elements (i.e. 32 * vscale columns from B)
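
Before the revert, passing use_sme=True switched both helpers from fixed tile extents to scalable ones. A minimal sketch of the pre-revert behaviour under the signatures deleted above; the call sites are illustrative:

import tvm
from tvm.topi.arm_cpu.arm_utils import get_tiling_A, get_tiling_B_transformed

# Pre-revert calls; the use_sme keyword is removed by this commit.
tile_M, tile_K = get_tiling_A(False, "float32", use_sme=True)
tile_N, _ = get_tiling_B_transformed(False, "float32", use_sme=True)
# Each extent is then the PrimExpr 2 * 4 * tvm.tir.vscale(), i.e. 8 * vscale
# elements, rather than a fixed integer such as 4 or 16.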
238 changes: 3 additions & 235 deletions python/tvm/topi/arm_cpu/conv2d.py
@@ -21,15 +21,13 @@
import tvm
from tvm import te
from tvm import autotvm
from tvm.script import tir as T
import tvm.contrib.nnpack
from tvm.tir.schedule.analysis import has_block

from ..utils import traverse_inline, get_const_tuple
from .. import nn
from ..nn.utils import get_const_int, get_pad_tuple
from ..nn.winograd_util import winograd_transform_matrices
from .arm_utils import get_tiling_A, get_tiling_B_transformed
from .arm_utils import get_tiling_B_transformed
from .conv2d_spatial_pack import (
conv2d_spatial_pack_nchw,
conv2d_spatial_pack_nhwc,
@@ -529,16 +527,13 @@ def compute_conv2d_NHWC(
out_dtype,
interleave_A,
use_scalable_vectors=False,
use_sme=False,
):
"""Compute definition for conv2d NHWC"""
N, IH, IW, IC = get_const_tuple(data.shape)
KH, KW, _, OC = get_const_tuple(kernel.shape)
tile_N, tile_K = get_tiling_B_transformed(
interleave_A, data.dtype, use_scalable_vectors, use_sme
)
tile_N, tile_K = get_tiling_B_transformed(interleave_A, data.dtype, use_scalable_vectors)

kernel = nn.conv2d_gemm_weight_transform(kernel, tile_N, tile_K, use_scalable_vectors, use_sme)
kernel = nn.conv2d_gemm_weight_transform(kernel, tile_N, tile_K, use_scalable_vectors)
return compute_conv2d_gemm_without_weight_transform(
cfg,
data,
@@ -551,7 +546,6 @@ def compute_conv2d_NHWC(
OC,
interleave_A,
use_scalable_vectors,
use_sme,
)


@@ -661,229 +655,3 @@ def compute_conv2d_NHWC_hybrid_SVE(cfg, data, kernel, strides, padding, dilation
def schedule_conv2d_NHWC_hybrid_SVE(cfg, outs):
"""Interface for hybrid schedule_conv2d_NHWC_hybrid_SVE"""
return schedule_conv2d_NHWC(cfg, outs, False)


@autotvm.register_topi_compute("conv2d_NHWC_hybrid_SME.arm_cpu")
def compute_conv2d_NHWC_hybrid_SME(cfg, data, kernel, strides, padding, dilation, out_dtype):
"""Interface for hybrid compute_conv2d_NHWC_hybrid_SME"""
return compute_conv2d_NHWC(
cfg,
data,
kernel,
strides,
padding,
dilation,
out_dtype,
False,
True,
True,
)


def schedule_conv2d_NHWC_hybrid_TIR(sch: tvm.tir.Schedule):
"""
Perform TIR scheduling for conv2d NHWC.
"""
# Get ordered buffer list
primfunc = sch.mod["main"]
buffer_names = primfunc.params
buffer_list = [primfunc.buffer_map[buf] for buf in buffer_names]
dtype = buffer_list[0].dtype

# Determine PrimFunc blocks
block_list = [
"data_pad",
"data_im2col",
"T_reshape",
"A_padded_K",
"A_padded_M",
"weight_flatten",
"C",
"conv2d_gemm_output",
]
func_blocks = {}
for block in block_list:
func_blocks[block] = sch.get_block(block) if has_block(sch, block) else None

gemm_block = func_blocks["C"]
b, m, n, k = sch.get_loops(gemm_block)

# Get tiling information
use_scalable_vectors = sch.get(func_blocks["conv2d_gemm_output"]).annotations[
"use_scalable_vectors"
]
use_sme = sch.get(func_blocks["conv2d_gemm_output"]).annotations["use_sme"]
M_padded = sch.get(m).extent
N_padded = sch.get(n).extent
K_padded = sch.get(k).extent
tile_M, tile_K = get_tiling_A(False, dtype, use_sme)
tile_N, _ = get_tiling_B_transformed(False, dtype, use_scalable_vectors, use_sme)
tile_M = T.cast(tile_M, M_padded.dtype)
tile_N = T.cast(tile_N, N_padded.dtype)
tile_K = T.cast(tile_K, K_padded.dtype)

# GeMM
# Compute each tile_M x tile_N tile
# By summing up K outer products
if use_sme:
# pylint: disable=import-outside-toplevel
from tvm.topi.arm_cpu.pstate_attributes import SMEAttributes
from tvm.tir.tensor_intrin.arm_cpu import (
ARM_SME_2SVLx2SVL_TRANSPOSE_INTERLEAVE,
ARM_SME_2SVLx2SVL_GEMM_INTERLEAVED_MOPA,
ARM_SME_INIT,
get_sme_gemm_interleaved_mopa_2svlx2svl_intrin,
)

# Interleave the padded im2col matrix utilizing the matrix tile
interleave_t_A_block = sch.cache_read(gemm_block, 0, "global")
sch.transform_layout(interleave_t_A_block, ("write", 0), lambda b, m, k: (b, k, m))
b, m, k = sch.get_loops(interleave_t_A_block)
mo, mi = sch.split(m, factors=(None, tile_M), disable_predication=True)
ko, ki = sch.split(k, factors=(None, tile_K), disable_predication=True)
sch.parallel(b)
sch.reorder(b, ko, mo, ki, mi)
sch.tensorize(ki, ARM_SME_2SVLx2SVL_TRANSPOSE_INTERLEAVE)

# Split and reorder the loops of the GeMM for tensorization
b, m, n, k = sch.get_loops(gemm_block)
mo, mi = sch.split(m, factors=(None, tile_M), disable_predication=True)
no, ni = sch.split(n, factors=(None, tile_N), disable_predication=True)
sch.parallel(b)
sch.reorder(b, mo, no, mi, ni, k)

# Tensorize the GeMM output matrix initialization to zero
init_block = sch.decompose_reduction(gemm_block, mi)
sch.tensorize(sch.get_loops(init_block)[-2], ARM_SME_INIT)

# Tensorize the GeMM update
sme_gemm_interleaved_intrin_name = ARM_SME_2SVLx2SVL_GEMM_INTERLEAVED_MOPA + f"_{K_padded}"
tvm.tir.TensorIntrin.register(
sme_gemm_interleaved_intrin_name,
*get_sme_gemm_interleaved_mopa_2svlx2svl_intrin(K_padded),
override=True,
)
sch.tensorize(mi, sme_gemm_interleaved_intrin_name)

# Add pstate annotations
root_block = sch.get_block("root")
sch.annotate(
root_block, SMEAttributes.STREAMING_MODE, SMEAttributes.StreamingModeValues.ENABLED
)
sch.annotate(root_block, SMEAttributes.ZA_STORAGE, SMEAttributes.ZAStorageValues.NEW)
elif use_scalable_vectors:
mo, mi = sch.split(m, [None, tile_M])
no, ni = sch.split(n, [None, tile_N], disable_predication=True)
ko, ki = sch.split(k, [None, tile_K])
b_mo_fused = sch.fuse(b, mo)
sch.parallel(b_mo_fused)
sch.reorder(
b_mo_fused,
no,
ko,
ki,
mi,
ni,
)
sch.vectorize(ni)
sch.unroll(mi)

# GeMM - Init
# Initialise an entire GeMM tile at once
sch.decompose_reduction(gemm_block, ko)
else:
mo, mi = sch.split(m, [None, tile_M])
no, ni = sch.split(n, [None, tile_N])
ko, ki = sch.split(k, [None, tile_K])
ni_outer, ni_inner = sch.split(ni, [4, None])
b_mo_fused = sch.fuse(b, mo)
sch.parallel(b_mo_fused)
sch.reorder(
b_mo_fused,
no,
ko,
ki,
ni_outer,
mi,
ni_inner,
)
sch.vectorize(ni_inner)
sch.unroll(mi)
sch.unroll(ni_outer)

# GeMM - Init
# Initialise an entire GeMM tile at once
sch.decompose_reduction(gemm_block, ko)

# Input padding
if func_blocks["data_pad"]:
input_padding_block = func_blocks["data_pad"]
b, h, w, ic = sch.get_loops(input_padding_block)
b_h_fused = sch.fuse(b, h)
sch.parallel(b_h_fused)

# Im2col + padding to tile size
# Computed outside GeMM
if func_blocks["data_im2col"]:
im2col_block = func_blocks["data_im2col"]
b1, m1, k1 = sch.get_loops(im2col_block)
b_m_fused_1 = sch.fuse(b1, m1)
if func_blocks["A_padded_K"]:
im2col_pad_K_block = func_blocks["A_padded_K"]
b2, m2, k2 = sch.get_loops(im2col_pad_K_block)
b_m_fused_2 = sch.fuse(b2, m2)
sch.parallel(b_m_fused_2)
sch.compute_at(im2col_block, b_m_fused_2)
_, k1 = sch.get_loops(sch.get_block("data_im2col"))
elif func_blocks["A_padded_M"]:
im2col_pad_M_block = func_blocks["A_padded_M"]
b2, m2, k2 = sch.get_loops(im2col_pad_M_block)
b_m_fused_2 = sch.fuse(b2, m2)
sch.parallel(b_m_fused_1)
sch.parallel(b_m_fused_2)
else:
sch.parallel(b_m_fused_1)

K = sch.get(k1).extent.value
if K % 16 == 0:
split_factor = 16
elif K % 8 == 0:
split_factor = 8
else:
IC = buffer_list[0].shape[3]
split_factor = IC
k_outer, k_inner = sch.split(k1, [None, split_factor])
sch.vectorize(k_inner)
sch.unroll(k_outer)

# Reshape + padding to tile size
# Computed inside GeMM
elif func_blocks["T_reshape"]:
reshape_block = func_blocks["T_reshape"]
A_pad_block = func_blocks["A_padded_K"] if func_blocks["A_padded_K"] else None
A_pad_block = func_blocks["A_padded_M"] if func_blocks["A_padded_M"] else A_pad_block
if use_sme:
sch.compute_inline(reshape_block)
elif A_pad_block:
sch.compute_inline(reshape_block)
b, m, k = sch.get_loops(A_pad_block)
_, k_inner = sch.split(k, [None, tile_N])
sch.vectorize(k_inner)
sch.compute_at(A_pad_block, mi)
else:
sch.compute_at(reshape_block, mi)

# Weight flattening
if func_blocks["weight_flatten"]:
weight_flatten_block = func_blocks["weight_flatten"]
sch.compute_inline(weight_flatten_block)

# Conv2d output block
output_block = func_blocks["conv2d_gemm_output"]
n, h, w, c = sch.get_loops(output_block)
n_h_fused = sch.fuse(n, h)
_, inner = sch.split(c, [None, 4])
sch.vectorize(inner)
sch.parallel(n_h_fused)

return sch
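
Taken together with the first hunk, the deleted dispatch amounted to this: any PrimFunc containing a "conv2d_gemm_output" block was routed through schedule_conv2d_NHWC_hybrid_TIR. A hedged sketch of that pre-revert flow; the wrapper function is illustrative and not part of the codebase:

import tvm
from tvm import tir
from tvm.tir.schedule.analysis import has_block
from tvm.topi.arm_cpu.conv2d import schedule_conv2d_NHWC_hybrid_TIR  # removed by this revert

def schedule_if_conv2d_gemm(mod: tvm.IRModule) -> tir.Schedule:
    # Mirrors the branch deleted from arm_cpu_tir_strategy.
    sch = tir.Schedule(mod)
    if has_block(sch, "conv2d_gemm_output"):
        schedule_conv2d_NHWC_hybrid_TIR(sch)
    return sch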