Commit

[TIR] [Hexagon] Add vdmpy intrinsic and transform_layout for tests (#13557)

[TIR] Add vdmpy intrinsic and transform_layout for tests

This patch adds the vdmpy Hexagon intrinsic and a sample tensorization test for it.
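
For reference, the computation the new intrinsic covers is a 32-lane dot product of int16 pairs accumulated into int32, i.e. C[i] += A[0] * B[i, 0] + A[1] * B[i, 1] for i in 0..31, mirroring the dot_product_32x2_i16i16i32_desc block added below. A minimal NumPy sketch of those semantics, for illustration only and not part of the patch:

import numpy as np

# 32x2 i16i16i32 dot product: each of the 32 output lanes accumulates a
# 2-element dot product of int16 values into an int32 accumulator.
A = np.random.randint(-100, 100, size=(2,), dtype=np.int16)
B = np.random.randint(-100, 100, size=(32, 2), dtype=np.int16)
C = np.zeros(32, dtype=np.int32)
C += (A.astype(np.int32) * B.astype(np.int32)).sum(axis=1)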

This patch also modifies the tests to use transform_layout instead of a hand-packed tensor in the compute definition, making it obvious that each example is just a matmul with a different data layout for one of the inputs.
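
The relationship between the two forms is a pure index remapping: element W[j, k] of the plain (n, k) weight ends up where the hand-packed tensor used to store it. A small sketch of that mapping, with lanes and vec as illustrative parameters (32 and 4 for vrmpy, 16 and 4 for VNNI, 32 and 2 for vdmpy):

def packed_index(j, k, lanes, vec):
    # W[j, k] in the plain layout corresponds to
    # packed_W[j // lanes, k // vec, j % lanes, k % vec],
    # the same index pattern the old compute spelled out by hand.
    return (j // lanes, k // vec, j % lanes, k % vec)
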
quic-sanirudh authored Dec 6, 2022
1 parent 6574e16 commit 8d31b25
Showing 2 changed files with 77 additions and 11 deletions.
python/tvm/tir/tensor_intrin/hexagon.py (46 additions, 0 deletions)
@@ -104,6 +104,48 @@ def dot_product_32x4_u8i8i32_vrmpy(a: T.handle, b: T.handle, c: T.handle) -> None:
    return dot_product_32x4_u8i8i32_desc, dot_product_32x4_u8i8i32_vrmpy


+def generate_dot_product_32x2_i16i16i32(mem_scope="global"):
+    @T.prim_func
+    def dot_product_32x2_i16i16i32_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
+        A = T.match_buffer(a, (2,), "int16", offset_factor=1, scope=mem_scope)
+        B = T.match_buffer(b, (32, 2), "int16", offset_factor=1, scope=mem_scope)
+        C = T.match_buffer(c, (32,), "int32", offset_factor=1, scope=mem_scope)
+        with T.block("root"):
+            T.reads(C[0:32], A[0:2], B[0:32, 0:2])
+            T.writes(C[0:32])
+            for i in T.serial(0, 32):
+                for k in T.serial(0, 2):
+                    with T.block("update"):
+                        vi, vk = T.axis.remap("SR", [i, k])
+                        C[vi] = C[vi] + T.cast(A[vk], "int32") * T.cast(B[vi, vk], "int32")
+
+    @T.prim_func
+    def dot_product_32x2_i16i16i32_vdmpy(a: T.handle, b: T.handle, c: T.handle) -> None:
+        A = T.match_buffer(a, (2,), "int16", offset_factor=1, scope=mem_scope)
+        B = T.match_buffer(b, (32, 2), "int16", offset_factor=1, scope=mem_scope)
+        C = T.match_buffer(c, (32,), "int32", offset_factor=1, scope=mem_scope)
+        with T.block("root"):
+            T.reads(C[0:32], A[0:2], B[0:32, 0:2])
+            T.writes(C[0:32])
+
+            A_i16x2 = A.vload([0], "int16x2")
+            A_i32 = T.reinterpret(A_i16x2, dtype="int32")
+
+            B_i16x64 = B.vload([0, 0], dtype="int16x64")
+            B_i32x32 = T.reinterpret(B_i16x64, dtype="int32x32")
+
+            C[T.ramp(T.int32(0), 1, 32)] = T.call_llvm_pure_intrin(
+                T.llvm_lookup_intrinsic_id("llvm.hexagon.V6.vdmpyhvsat.acc.128B"),
+                T.uint32(3),
+                C[T.ramp(T.int32(0), 1, 32)],
+                T.Broadcast(A_i32, 32),
+                B_i32x32,
+                dtype="int32x32",
+            )
+
+    return dot_product_32x2_i16i16i32_desc, dot_product_32x2_i16i16i32_vdmpy
+
+
VRMPY_u8u8i32_INTRIN = "dot_32x4_u8u8i32_vrmpy"

TensorIntrin.register(VRMPY_u8u8i32_INTRIN, *generate_dot_product_32x4_u8u8i32())
@@ -112,6 +154,10 @@ def dot_product_32x4_u8i8i32_vrmpy(a: T.handle, b: T.handle, c: T.handle) -> None:

TensorIntrin.register(VRMPY_u8i8i32_INTRIN, *generate_dot_product_32x4_u8i8i32())

+VDMPY_i16i16i32_INTRIN = "dot_product_32x2_i16i16i32_vdmpy"
+
+TensorIntrin.register(VDMPY_i16i16i32_INTRIN, *generate_dot_product_32x2_i16i16i32())
+
VRMPY_u8u8i32_VTCM_INTRIN = "dot_32x4_u8u8i32_vtcm_vrmpy"
TensorIntrin.register(VRMPY_u8u8i32_VTCM_INTRIN, *generate_dot_product_32x4_u8u8i32("global.vtcm"))

tests/python/unittest/test_tir_schedule_tensorize.py (31 additions, 11 deletions)
@@ -30,7 +30,7 @@
)
from tvm.tir.tensor_intrin.rocm import AMDGPU_SDOT4_INTRIN
from tvm.tir.tensor_intrin.x86 import VNNI_DOT_16x4_INTRIN
-from tvm.tir.tensor_intrin.hexagon import VRMPY_u8u8i32_INTRIN
+from tvm.tir.tensor_intrin.hexagon import VRMPY_u8u8i32_INTRIN, VDMPY_i16i16i32_INTRIN

# fmt: off
# pylint: disable=no-member,invalid-name,unused-variable,line-too-long,redefined-outer-name,unexpected-keyword-arg,too-many-nested-blocks
@@ -540,33 +540,31 @@ def test_tensorize_with_annotation():
    verify_trace_roundtrip(sch=s, mod=func)


-def get_matmul_packed(m, n, k, lhs_type, int32_lanes, rhs_dtype="int8"):
+def get_matmul_packed(m, n, k, lhs_type, rhs_dtype="int8"):
    X = te.placeholder((m, k), name="X", dtype=lhs_type)
-    packed_W = te.placeholder((n // int32_lanes, k // 4, int32_lanes, 4), name="packedW", dtype=rhs_dtype)
+    W = te.placeholder((n, k), name="W", dtype=rhs_dtype)

    ak = te.reduce_axis((0, k), name="k")
    matmul = te.compute(
        (m, n),
        lambda i, j: te.sum(
-            X[i, ak].astype("int32")
-            * packed_W[
-                tvm.tir.indexdiv(j, int32_lanes), tvm.tir.indexdiv(ak, 4), j % int32_lanes, ak % 4
-            ].astype("int32"),
+            X[i, ak].astype("int32") * W[j, ak].astype("int32"),
            axis=ak,
        ),
        name="compute",
    )

-    return te.create_prim_func([X, packed_W, matmul])
+    return te.create_prim_func([X, W, matmul])


def test_tensorize_vnni():
    m, n, k = 128, 128, 128

-    func = get_matmul_packed(m, n, k, "uint8", 16)
+    func = get_matmul_packed(m, n, k, "uint8")

    sch = tir.Schedule(func, debug_mask="all")
    block = sch.get_block("compute")
+    sch.transform_layout(block, "W", lambda i, j: [i//16, j//4, i%16, j%4])
    _, j, k = sch.get_loops(block)

    _, ji = sch.split(j, factors=[None, 16])
@@ -582,11 +580,12 @@ def test_tensorize_vnni():
def test_tensorize_arm_dot():
    m, n, k = 128, 128, 128

-    func = get_matmul_packed(m, n, k, "int8", 4)
+    func = get_matmul_packed(m, n, k, "int8")

    for intrin in [ARM_DOT_4x4_i8_SDOT_INTRIN, ARM_DOT_4x4_i8_NEON_INTRIN]:
        sch = tir.Schedule(func, debug_mask="all")
        block = sch.get_block("compute")
+        sch.transform_layout(block, "W", lambda i, j: [i//4, j//4, i%4, j%4])
        _, j, k = sch.get_loops(block)

        _, ji = sch.split(j, factors=[None, 4])
@@ -602,10 +601,11 @@ def test_tensorize_arm_dot():
def test_tensorize_vrmpy():
    m, n, k = 128, 128, 128

-    func = get_matmul_packed(m, n, k, "uint8", 32, "uint8")
+    func = get_matmul_packed(m, n, k, "uint8", "uint8")

    sch = tir.Schedule(func, debug_mask="all")
    block = sch.get_block("compute")
+    sch.transform_layout(block, "W", lambda i, j: [i//32, j//4, i%32, j%4])
    _, j, k = sch.get_loops(block)

    _, ji = sch.split(j, factors=[None, 32])
@@ -618,6 +618,26 @@ def test_tensorize_vrmpy():
    verify_trace_roundtrip(sch=sch, mod=func)


+def test_tensorize_vdmpy():
+    m, n, k = 128, 128, 128
+
+    func = get_matmul_packed(m, n, k, "int16", "int16")
+
+    sch = tir.Schedule(func, debug_mask="all")
+    block = sch.get_block("compute")
+    sch.transform_layout(block, "W", lambda i, j: [i//32, j//2, i%32, j%2])
+    _, j, k = sch.get_loops(block)
+
+    _, ji = sch.split(j, factors=[None, 32])
+    ko, ki = sch.split(k, factors=[None, 2])
+    sch.reorder(ko, ji, ki)
+
+    sch.decompose_reduction(block, ko)
+    sch.tensorize(ji, VDMPY_i16i16i32_INTRIN)
+
+    verify_trace_roundtrip(sch=sch, mod=func)
+
+
def test_tensorize_dpa4():
    m, n, k = 128, 128, 128

