diff --git a/python/tvm/script/tir/__init__.pyi b/python/tvm/script/tir/__init__.pyi
index 1be249bc9e89..3eb383ed9974 100644
--- a/python/tvm/script/tir/__init__.pyi
+++ b/python/tvm/script/tir/__init__.pyi
@@ -124,6 +124,8 @@ def Select(cond: PrimExpr, if_body: PrimExpr, else_body: PrimExpr) -> PrimExpr:
 def if_then_else(cond: PrimExpr, t: PrimExpr, f: PrimExpr, dtype: str) -> PrimExpr: ...
 def evaluate(value: PrimExpr) -> None: ...
 def reinterpret(value: PrimExpr, dtype: str) -> PrimExpr: ...
+def vectorlow(value: PrimExpr, dtype: str) -> PrimExpr: ...
+def vectorhigh(value: PrimExpr, dtype: str) -> PrimExpr: ...
 def store(
     var: Var, index: PrimExpr, value: PrimExpr, predicate: Union[PrimExpr, builtins.bool] = True
 ) -> None: ...
@@ -143,7 +145,7 @@ def preflattened_buffer(
 ) -> Buffer: ...
 
 """
-Intrinsics - tvm builtin 
+Intrinsics - tvm builtin
 """
 
 def tvm_thread_allreduce(
diff --git a/python/tvm/script/tir/special_stmt.py b/python/tvm/script/tir/special_stmt.py
index 0148bd0b4243..3d0fb407ef3f 100644
--- a/python/tvm/script/tir/special_stmt.py
+++ b/python/tvm/script/tir/special_stmt.py
@@ -25,10 +25,9 @@
 
 import tvm.tir
 from tvm.runtime import Object, String
-from tvm import te
 from tvm.target import Target
 from tvm.ir import Span
-from tvm.tir import IntImm, IterVar
+from tvm.tir import IntImm, IterVar, Var
 
 from .node import BufferSlice
 from .utils import buffer_slice_to_region
@@ -800,7 +799,7 @@ def var(dtype, span):
                 self.context.report_error(
                     f"VarDef expected assign to only one var, but got {names}", span
                 )
-            v = te.var(names[0], dtype, span=span)
+            v = Var(names[0], dtype, span=span)
             self.context.update_symbol(v.name, v, self.node)
 
         super().__init__(var, def_symbol=True)
@@ -821,7 +820,7 @@ def buffer_var(dtype, storage_scope, span):
                     f"VarDef expected assign to only one var, but got {names}", span
                 )
             ptr_type = tvm.ir.PointerType(tvm.ir.PrimType(dtype), storage_scope)
-            v = te.var(names[0], ptr_type, span=span)
+            v = Var(names[0], ptr_type, span=span)
             self.context.update_symbol(v.name, v, self.node)
 
         super().__init__(buffer_var, def_symbol=True)
@@ -841,7 +840,7 @@ def env_thread(env_name, span):
                 self.context.report_error(
                     f"VarDef expected assign to only one var, but got {names}", span
                 )
-            v = te.var(names[0], span=span)
+            v = Var(names[0], dtype="int32", span=span)
             self.context.func_var_env_dict[v] = env_name
             self.context.update_symbol(v.name, v, self.node)
 
diff --git a/python/tvm/tir/tensor_intrin/__init__.py b/python/tvm/tir/tensor_intrin/__init__.py
new file mode 100644
index 000000000000..62159851b3d4
--- /dev/null
+++ b/python/tvm/tir/tensor_intrin/__init__.py
@@ -0,0 +1,20 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=unused-import
+"""Intrinsics for tensorization."""
+from .x86 import *
+from .arm_cpu import *
diff --git a/python/tvm/tir/tensor_intrin/arm_cpu.py b/python/tvm/tir/tensor_intrin/arm_cpu.py
new file mode 100644
index 000000000000..6e16b1f767f3
--- /dev/null
+++ b/python/tvm/tir/tensor_intrin/arm_cpu.py
@@ -0,0 +1,142 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name,missing-function-docstring
+"""Intrinsics for ARM tensorization."""
+from tvm.script import tir as T
+from .. import TensorIntrin
+
+
+# TODO(masahi): Parametrize the TVMScript description of dot product by
+# shape and dtype, and share the common description with x86.
+
+
+@T.prim_func
+def dot_product_4x4_i8i8i32_desc(
+    A: T.Buffer((4,), "int8", offset_factor=1),
+    B: T.Buffer((4, 4), "int8", offset_factor=1),
+    C: T.Buffer((4,), "int32", offset_factor=1),
+) -> None:
+    with T.block("root"):
+        T.reads(C[0:4], A[0:4], B[0:4, 0:4])
+        T.writes(C[0:4])
+        for i in T.serial(0, 4):
+            with T.init():
+                C[i] = T.int32(0)
+            for k in T.serial(0, 4):
+                with T.block("update"):
+                    vi, vk = T.axis.remap("SR", [i, k])
+                    C[vi] = C[vi] + T.cast(A[vk], "int32") * T.cast(B[vi, vk], "int32")
+
+
+@T.prim_func
+def dot_product_4x4_i8i8i32_neon(
+    A: T.Buffer((4,), "int8", offset_factor=1),
+    B: T.Buffer((4, 4), "int8", offset_factor=1),
+    C: T.Buffer((4,), "int32", offset_factor=1),
+) -> None:
+    with T.block("root"):
+        T.reads(C[0:4], A[0:4], B[0:4, 0:4])
+        T.writes(C[0:4])
+
+        A_int8 = A.vload([0], "int8x4")
+        re_int32 = T.reinterpret(A_int8, dtype="int32")
+        vec_ai32 = T.broadcast(re_int32, 2)
+        vec_a = T.reinterpret(vec_ai32, dtype="int8x8")
+
+        vec_b = B.vload([0, 0], dtype="int8x16")
+
+        # TODO(masahi): Remove duplication when inlined function call is supported
+        vec_b_low = T.vectorlow(vec_b, dtype="int8x8")
+
+        multiply_low = T.call_llvm_pure_intrin(
+            T.llvm_lookup_intrinsic_id("llvm.aarch64.neon.smull.v8i16"),
+            T.uint32(2),
+            vec_a,
+            vec_b_low,
+            dtype="int16x8",
+        )
+
+        pairwise_reduction_low = T.call_llvm_pure_intrin(
+            T.llvm_lookup_intrinsic_id("llvm.aarch64.neon.saddlp.v4i32.v8i16"),
+            T.uint32(1),
+            multiply_low,
+            dtype="int32x4",
+        )
+
+        vec_b_high = T.vectorhigh(vec_b, dtype="int8x8")
+
+        multiply_high = T.call_llvm_pure_intrin(
+            T.llvm_lookup_intrinsic_id("llvm.aarch64.neon.smull.v8i16"),
+            T.uint32(2),
+            vec_a,
+            vec_b_high,
+            dtype="int16x8",
+        )
+
+        pairwise_reduction_high = T.call_llvm_pure_intrin(
+            T.llvm_lookup_intrinsic_id("llvm.aarch64.neon.saddlp.v4i32.v8i16"),
+            T.uint32(1),
+            multiply_high,
+            dtype="int32x4",
+        )
+
+        C[T.ramp(T.int32(0), 1, 4)] += T.call_llvm_pure_intrin(
+            T.llvm_lookup_intrinsic_id("llvm.aarch64.neon.addp.v4i32"),
+            T.uint32(2),
+            pairwise_reduction_low,
+            pairwise_reduction_high,
+            dtype="int32x4",
+        )
+
+
+@T.prim_func
+def dot_product_4x4_i8i8i32_sdot(
+    A: T.Buffer((4,), "int8", offset_factor=1),
+    B: T.Buffer((4, 4), "int8", offset_factor=1),
+    C: T.Buffer((4,), "int32", offset_factor=1),
+) -> None:
+    with T.block("root"):
+        T.reads(C[0:4], A[0:4], B[0:4, 0:4])
+        T.writes(C[0:4])
+
+        A_i8x4 = A.vload([0], "int8x4")
+        A_i32 = T.reinterpret(A_i8x4, dtype="int32")
+        vec_ai32 = T.broadcast(A_i32, 4)
+        vec_a = T.reinterpret(vec_ai32, dtype="int8x16")
+
+        vec_b = B.vload([0, 0], dtype="int8x16")
+
+        C[T.ramp(T.int32(0), 1, 4)] += T.call_llvm_pure_intrin(
+            T.llvm_lookup_intrinsic_id("llvm.aarch64.neon.sdot.v4i32.v16i8"),
+            T.uint32(3),
+            T.int32x4(0),
+            vec_a,
+            vec_b,
+            dtype="int32x4",
+        )
+
+
+ARM_DOT_4x4_i8_NEON_INTRIN = "dot_4x4_i8i8s32_neon"
+ARM_DOT_4x4_i8_SDOT_INTRIN = "dot_4x4_i8i8s32_sdot"
+
+TensorIntrin.register(
+    ARM_DOT_4x4_i8_NEON_INTRIN, dot_product_4x4_i8i8i32_desc, dot_product_4x4_i8i8i32_neon
+)
+
+TensorIntrin.register(
+    ARM_DOT_4x4_i8_SDOT_INTRIN, dot_product_4x4_i8i8i32_desc, dot_product_4x4_i8i8i32_sdot
+)
diff --git a/python/tvm/tir/tensor_intrin/x86.py b/python/tvm/tir/tensor_intrin/x86.py
new file mode 100644
index 000000000000..ee57c9aa4750
--- /dev/null
+++ b/python/tvm/tir/tensor_intrin/x86.py
@@ -0,0 +1,75 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name,missing-function-docstring
+"""Intrinsics for x86 tensorization."""
+from tvm.script import tir as T
+from .. import TensorIntrin
+
+
+# Tensorized intrinsic description and VNNI-specific implementation.
+# Equivalent to the ones in topi/x86/tensor_intrin.py
+
+
+@T.prim_func
+def dot_product_16x4_u8i8i32_desc(
+    A: T.Buffer((4,), "uint8", offset_factor=1),
+    B: T.Buffer((16, 4), "int8", offset_factor=1),
+    C: T.Buffer((16,), "int32", offset_factor=1),
+) -> None:
+    with T.block("root"):
+        T.reads(C[0:16], A[0:4], B[0:16, 0:4])
+        T.writes(C[0:16])
+        for i in T.serial(0, 16):
+            with T.init():
+                C[i] = T.int32(0)
+            for k in T.serial(0, 4):
+                with T.block("update"):
+                    vi, vk = T.axis.remap("SR", [i, k])
+                    C[vi] = C[vi] + T.cast(A[vk], "int32") * T.cast(B[vi, vk], "int32")
+
+
+@T.prim_func
+def dot_product_16x4_u8i8i32_vnni(
+    A: T.Buffer((4,), "uint8", offset_factor=1),
+    B: T.Buffer((16, 4), "int8", offset_factor=1),
+    C: T.Buffer((16,), "int32", offset_factor=1),
+) -> None:
+    with T.block("root"):
+        T.reads(C[0:16], A[0:4], B[0:16, 0:4])
+        T.writes(C[0:16])
+
+        A_u8x4 = A.vload([0], "uint8x4")
+        A_i32 = T.reinterpret(A_u8x4, dtype="int32")
+
+        B_i8x64 = B.vload([0, 0], dtype="int8x64")
+        B_i32x16 = T.reinterpret(B_i8x64, dtype="int32x16")
+
+        C[T.ramp(T.int32(0), 1, 16)] += T.call_llvm_pure_intrin(  # Note: this is an update +=
+            T.llvm_lookup_intrinsic_id("llvm.x86.avx512.vpdpbusd.512"),
+            T.uint32(0),
+            T.int32x16(0),
+            T.broadcast(A_i32, 16),
+            B_i32x16,
+            dtype="int32x16",
+        )
+
+
+VNNI_DOT_16x4_INTRIN = "dot_16x4_vnni"
+
+TensorIntrin.register(
+    VNNI_DOT_16x4_INTRIN, dot_product_16x4_u8i8i32_desc, dot_product_16x4_u8i8i32_vnni
+)
diff --git a/tests/python/unittest/test_meta_schedule_tune_relay.py b/tests/python/unittest/test_meta_schedule_tune_relay.py
index af25d2a6f39e..64b8795c5eaf 100644
--- a/tests/python/unittest/test_meta_schedule_tune_relay.py
+++ b/tests/python/unittest/test_meta_schedule_tune_relay.py
@@ -38,6 +38,8 @@
 from tvm.target.target import Target
 from tvm.tir.schedule import BlockRV, Schedule
 from tvm.tir.schedule.trace import Trace
+from tvm.tir.tensor_intrin.x86 import VNNI_DOT_16x4_INTRIN as VNNI_INTRIN
+
 
 logging.basicConfig()
 logging.getLogger("tvm.meta_schedule").setLevel(logging.DEBUG)
@@ -328,57 +330,6 @@ def get_output(data, lib):
     assert np.allclose(actual_output, expected_output, rtol=1e-4, atol=2e-4)
 
 
-# Tensorized intrinsic description and VNNI-specific implementation.
-# Equivalent to the ones in topi/x86/tensor_intrin.py
-
-
-@T.prim_func
-def dot_product_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
-    A = T.match_buffer(a, (4,), "uint8", offset_factor=1)
-    B = T.match_buffer(b, (16, 4), "int8", offset_factor=1)
-    C = T.match_buffer(c, (16,), "int32", offset_factor=1)
-
-    with T.block("root"):
-        T.reads(C[0:16], A[0:4], B[0:16, 0:4])
-        T.writes(C[0:16])
-        for i in T.serial(0, 16):
-            with T.init():
-                C[i] = T.int32(0)
-            for k in T.serial(0, 4):
-                with T.block("update"):
-                    vi, vk = T.axis.remap("SR", [i, k])
-                    C[vi] = C[vi] + T.cast(A[vk], "int32") * T.cast(B[vi, vk], "int32")
-
-
-@T.prim_func
-def dot_product_vnni(a: T.handle, b: T.handle, c: T.handle) -> None:
-    A = T.match_buffer(a, (4,), "uint8", offset_factor=1)
-    B = T.match_buffer(b, (16, 4), "int8", offset_factor=1)
-    C = T.match_buffer(c, (16,), "int32", offset_factor=1)
-
-    with T.block("root"):
-        T.reads(C[0:16], A[0:4], B[0:16, 0:4])
-        T.writes(C[0:16])
-
-        A_u8x4 = A.vload([0], "uint8x4")
-        A_i32 = T.reinterpret(A_u8x4, dtype="int32")
-
-        B_i8x64 = B.vload([0, 0], dtype="int8x64")
-        B_i32x16 = T.reinterpret(B_i8x64, dtype="int32x16")
-
-        C[T.ramp(T.int32(0), 1, 16)] += T.call_llvm_pure_intrin(  # Note: this is an update +=
-            T.llvm_lookup_intrinsic_id("llvm.x86.avx512.vpdpbusd.512"),
-            T.uint32(0),
-            T.int32x16(0),
-            T.broadcast(A_i32, 16),
-            B_i32x16,
-            dtype="int32x16",
-        )
-
-
-VNNI_INTRIN = "dot_16x1x16_uint8_int8_int32_cascadelake"
-
-
 def schedule_dense(dense_block, M, do_tune, sch):
     """
     Manually schedule a dense block, created from TE compute op via CreatePrimFunc,
@@ -546,10 +497,6 @@ def schedule_fn(task, sch):
 
 @pytest.mark.skip("Requires cascadelake")
 def test_tune_relay_manual_tir_vnni():
-    # Register a pair of an intrinsic description for 16x4 dot product, and its
-    # VNNI-specific implementation.
-    tir.TensorIntrin.register(VNNI_INTRIN, dot_product_desc, dot_product_vnni)
-
     manual_tir_common(do_tune=False)
 
     """
diff --git a/tests/python/unittest/test_tir_schedule_tensorize.py b/tests/python/unittest/test_tir_schedule_tensorize.py
index 5cef8d63587d..482d6f3db574 100644
--- a/tests/python/unittest/test_tir_schedule_tensorize.py
+++ b/tests/python/unittest/test_tir_schedule_tensorize.py
@@ -19,9 +19,14 @@
 import pytest
 import tvm
 import tvm.testing
-from tvm import tir
+from tvm import tir, te
 from tvm.script import tir as T
 from tvm.tir.schedule.testing import verify_trace_roundtrip
+from tvm.tir.tensor_intrin import (
+    VNNI_DOT_16x4_INTRIN,
+    ARM_DOT_4x4_i8_NEON_INTRIN,
+    ARM_DOT_4x4_i8_SDOT_INTRIN,
+)
 
 # fmt: off
 # pylint: disable=no-member,invalid-name,unused-variable,line-too-long,redefined-outer-name,unexpected-keyword-arg,too-many-nested-blocks
@@ -531,5 +536,65 @@ def test_tensorize_with_annotation():
     verify_trace_roundtrip(sch=s, mod=func)
 
 
+def get_matmul_packed(m, n, k, lhs_type, int32_lanes):
+    """Create a PrimFunc for X[m, k] times int8 W packed as (n / lanes, k / 4, lanes, 4)."""
+    X = te.placeholder((m, k), name="X", dtype=lhs_type)
+    packed_W = te.placeholder((n // int32_lanes, k // 4, int32_lanes, 4), name="packedW", dtype="int8")
+
+    ak = te.reduce_axis((0, k), name="k")
+    matmul = te.compute(
+        (m, n),
+        lambda i, j: te.sum(
+            X[i, ak].astype("int32")
+            * packed_W[
+                tvm.tir.indexdiv(j, int32_lanes), tvm.tir.indexdiv(ak, 4), j % int32_lanes, ak % 4
+            ].astype("int32"),
+            axis=ak,
+        ),
+        name="compute",
+    )
+
+    return te.create_prim_func([X, packed_W, matmul])
+
+
+def test_tensorize_vnni():
+    m, n, k = 128, 128, 128
+
+    func = get_matmul_packed(m, n, k, "uint8", 16)
+
+    sch = tir.Schedule(func, debug_mask="all")
+    block = sch.get_block("compute")
+    _, j, k = sch.get_loops(block)
+
+    _, ji = sch.split(j, factors=[None, 16])
+    ko, ki = sch.split(k, factors=[None, 4])
+    sch.reorder(ko, ji, ki)
+
+    sch.decompose_reduction(block, ko)
+    sch.tensorize(ji, VNNI_DOT_16x4_INTRIN)
+
+    verify_trace_roundtrip(sch=sch, mod=func)
+
+
+def test_tensorize_arm_dot():
+    m, n, k = 128, 128, 128
+
+    func = get_matmul_packed(m, n, k, "int8", 4)
+
+    for intrin in [ARM_DOT_4x4_i8_SDOT_INTRIN, ARM_DOT_4x4_i8_NEON_INTRIN]:
+        sch = tir.Schedule(func, debug_mask="all")
+        block = sch.get_block("compute")
+        _, j, k = sch.get_loops(block)
+
+        _, ji = sch.split(j, factors=[None, 4])
+        ko, ki = sch.split(k, factors=[None, 4])
+        sch.reorder(ko, ji, ki)
+
+        sch.decompose_reduction(block, ko)
+        sch.tensorize(ji, intrin)
+
+        verify_trace_roundtrip(sch=sch, mod=func)
+
+
 if __name__ == "__main__":
     sys.exit(pytest.main([__file__] + sys.argv[1:]))