[Hexagon] Fix TIR vrmpy tensorization (#13404)
[Hexagon] Fix vrmpy tensorization
masahi authored Nov 16, 2022
1 parent 271ad43 commit b4d4b82
Showing 2 changed files with 23 additions and 7 deletions.
python/tvm/tir/tensor_intrin/hexagon.py (4 changes: 0 additions & 4 deletions)
@@ -32,8 +32,6 @@ def dot_product_32x4_u8u8i32_desc(
         for i in T.serial(0, 32):
             for k in T.serial(0, 4):
                 with T.block("update"):
-                    with T.init():
-                        C[i] = T.int32(0)
                     vi, vk = T.axis.remap("SR", [i, k])
                     C[vi] = C[vi] + T.cast(A[vk], "int32") * T.cast(B[vi, vk], "int32")

@@ -76,8 +74,6 @@ def dot_product_32x4_u8i8i32_desc(
         for i in T.serial(0, 32):
             for k in T.serial(0, 4):
                 with T.block("update"):
-                    with T.init():
-                        C[i] = T.int32(0)
                     vi, vk = T.axis.remap("SR", [i, k])
                     C[vi] = C[vi] + T.cast(A[vk], "int32") * T.cast(B[vi, vk], "int32")

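With the T.init() lines gone, each descriptor covers only the pure update step; the new test below calls sch.decompose_reduction(...) before tensorizing, so the accumulator initialization comes from the schedule rather than from the intrinsic description. A minimal sketch of the full corrected u8u8i32 descriptor, with the buffer parameter declarations and root-block reads/writes assumed here since they fall outside the hunk shown above:

    # Sketch only, not the verbatim file contents. The buffer shapes/dtypes follow
    # the vrmpy u8u8 -> i32 contract: A is 4 x uint8, B is 32 x 4 uint8, C is 32 x int32.
    from tvm.script import tir as T

    @T.prim_func
    def dot_product_32x4_u8u8i32_desc(
        A: T.Buffer((4,), "uint8", offset_factor=1),
        B: T.Buffer((32, 4), "uint8", offset_factor=1),
        C: T.Buffer((32,), "int32", offset_factor=1),
    ) -> None:
        with T.block("root"):
            T.reads(C[0:32], A[0:4], B[0:32, 0:4])
            T.writes(C[0:32])
            for i in T.serial(0, 32):
                for k in T.serial(0, 4):
                    with T.block("update"):
                        # No T.init() anymore: the accumulator is expected to be
                        # zeroed by the schedule (decompose_reduction) before
                        # tensorization pattern-matches this update block.
                        vi, vk = T.axis.remap("SR", [i, k])
                        C[vi] = C[vi] + T.cast(A[vk], "int32") * T.cast(B[vi, vk], "int32")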
tests/python/unittest/test_tir_schedule_tensorize.py (26 changes: 23 additions & 3 deletions)
@@ -30,6 +30,7 @@
 )
 from tvm.tir.tensor_intrin.rocm import AMDGPU_SDOT4_INTRIN
 from tvm.tir.tensor_intrin.x86 import VNNI_DOT_16x4_INTRIN
+from tvm.tir.tensor_intrin.hexagon import VRMPY_u8u8i32_INTRIN
 
 # fmt: off
 # pylint: disable=no-member,invalid-name,unused-variable,line-too-long,redefined-outer-name,unexpected-keyword-arg,too-many-nested-blocks
@@ -539,17 +540,17 @@ def test_tensorize_with_annotation():
     verify_trace_roundtrip(sch=s, mod=func)
 
 
-def get_matmul_packed(m, n, k, lhs_type, int32_lanes):
+def get_matmul_packed(m, n, k, lhs_type, int32_lanes, rhs_dtype="int8"):
     X = te.placeholder((m, k), name="X", dtype=lhs_type)
-    packed_W = te.placeholder((n // int32_lanes, k // 4, int32_lanes, 4), name="packedW", dtype="int8")
+    packed_W = te.placeholder((n // int32_lanes, k // 4, int32_lanes, 4), name="packedW", dtype=rhs_dtype)
 
     ak = te.reduce_axis((0, k), name="k")
     matmul = te.compute(
         (m, n),
         lambda i, j: te.sum(
             X[i, ak].astype("int32")
             * packed_W[
-                tvm.tir.indexdiv(j, 16), tvm.tir.indexdiv(ak, 4), j % 16, ak % 4
+                tvm.tir.indexdiv(j, int32_lanes), tvm.tir.indexdiv(ak, 4), j % int32_lanes, ak % 4
             ].astype("int32"),
             axis=ak,
         ),
@@ -598,6 +599,25 @@ def test_tensorize_arm_dot():
     verify_trace_roundtrip(sch=sch, mod=func)
 
 
+def test_tensorize_vrmpy():
+    m, n, k = 128, 128, 128
+
+    func = get_matmul_packed(m, n, k, "uint8", 32, "uint8")
+
+    sch = tir.Schedule(func, debug_mask="all")
+    block = sch.get_block("compute")
+    _, j, k = sch.get_loops(block)
+
+    _, ji = sch.split(j, factors=[None, 32])
+    ko, ki = sch.split(k, factors=[None, 4])
+    sch.reorder(ko, ji, ki)
+
+    sch.decompose_reduction(block, ko)
+    sch.tensorize(ji, VRMPY_u8u8i32_INTRIN)
+
+    verify_trace_roundtrip(sch=sch, mod=func)
+
+
 def test_tensorize_dpa4():
     m, n, k = 128, 128, 128
 
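To try the new test path by hand, the following is a self-contained sketch (not part of the commit) that rebuilds the same uint8 x uint8 packed matmul as get_matmul_packed(m, n, k, "uint8", 32, "uint8"), applies the scheduling steps from test_tensorize_vrmpy, and prints the tensorized TIR. Nothing is compiled, so a regular TVM build with the Hexagon tensor intrinsics registered on import should suffice:

    # Standalone sketch mirroring test_tensorize_vrmpy; the names X, packed_W, ak
    # follow the test file, everything else is standard TVM API.
    import tvm
    from tvm import te, tir
    from tvm.tir.tensor_intrin.hexagon import VRMPY_u8u8i32_INTRIN

    m, n, k = 128, 128, 128
    X = te.placeholder((m, k), name="X", dtype="uint8")
    packed_W = te.placeholder((n // 32, k // 4, 32, 4), name="packedW", dtype="uint8")
    ak = te.reduce_axis((0, k), name="k")
    matmul = te.compute(
        (m, n),
        lambda i, j: te.sum(
            X[i, ak].astype("int32")
            * packed_W[
                tvm.tir.indexdiv(j, 32), tvm.tir.indexdiv(ak, 4), j % 32, ak % 4
            ].astype("int32"),
            axis=ak,
        ),
        name="compute",
    )
    func = te.create_prim_func([X, packed_W, matmul])

    sch = tir.Schedule(func, debug_mask="all")
    block = sch.get_block("compute")
    _, j_loop, k_loop = sch.get_loops(block)
    _, ji = sch.split(j_loop, factors=[None, 32])
    ko, ki = sch.split(k_loop, factors=[None, 4])
    sch.reorder(ko, ji, ki)
    sch.decompose_reduction(block, ko)       # pull the C[...] = 0 init out of the reduction
    sch.tensorize(ji, VRMPY_u8u8i32_INTRIN)  # replace the 32x4 update loops with vrmpy
    print(sch.mod.script())                  # inspect the tensorized TIR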
