From 7b7677fc757ad003aa85ad481f2a4bba6d77957a Mon Sep 17 00:00:00 2001
From: Lei Wang <34334180+LeiWang1999@users.noreply.github.com>
Date: Fri, 8 Mar 2024 08:38:53 +0800
Subject: [PATCH] [TIR] Enhance and fix tensorize schedule for some case (#16560)

* support tensorize with simplified and call expr
* replace stmt simplifier with primfunc simplifier
* lint fix
* lint:remove white space
* lint: remove white space
* cpp lint fix
* lint: resolve include
* clang format lint fix
---
 src/tir/schedule/ir_comparator.cc              |  24 ++++
 src/tir/schedule/ir_comparator.h               |   1 +
 .../schedule/primitive/blockize_tensorize.cc   |   5 +-
 src/tir/transforms/simplify.cc                 |   8 ++
 src/tir/transforms/simplify.h                  |   9 +-
 .../test_tir_schedule_tensorize.py             | 118 ++++++++++++++++++
 6 files changed, 159 insertions(+), 6 deletions(-)

diff --git a/src/tir/schedule/ir_comparator.cc b/src/tir/schedule/ir_comparator.cc
index 5353a051a60a..00e573eaf6e4 100644
--- a/src/tir/schedule/ir_comparator.cc
+++ b/src/tir/schedule/ir_comparator.cc
@@ -83,6 +83,30 @@ bool TensorizeComparator::VisitExpr(const PrimExpr& n, const PrimExpr& other) {
   return equal;
 }
 
+bool TensorizeComparator::VisitExpr_(const CallNode* op, const PrimExpr& other) {
+  const auto* rhs = other.as<CallNode>();
+  if (!rhs->op.same_as(op->op)) return false;
+  if (op->dtype.code() != rhs->dtype.code()) {
+    if (assert_mode_) {
+      std::ostringstream os;
+      os << "CallNode data type codes do not match: op->dtype.code()=" << op->dtype.code()
+         << " vs rhs->dtype.code()=" << rhs->dtype.code();
+      EmitError(os.str());
+    }
+    return false;
+  }
+  if (!CompareArray(op->args, rhs->args, &TensorizeComparator::VisitExpr)) {
+    if (assert_mode_) {
+      std::ostringstream os;
+      os << "CallNode args do not match: op->args=" << op->args
+         << " vs rhs->args=" << rhs->args;
+      EmitError(os.str());
+    }
+    return false;
+  }
+  return true;
+}
+
 bool TensorizeComparator::VisitStmt_(const ForNode* op, const Stmt& other) {
   const auto* rhs = other.as<ForNode>();
   if (!DefEqual(op->loop_var, rhs->loop_var)) {
diff --git a/src/tir/schedule/ir_comparator.h b/src/tir/schedule/ir_comparator.h
index debf0f946e28..f86dbd358391 100644
--- a/src/tir/schedule/ir_comparator.h
+++ b/src/tir/schedule/ir_comparator.h
@@ -46,6 +46,7 @@ class TensorizeComparator : public ExprComparator, public StmtComparator {
   bool VisitExpr(const PrimExpr& n, const PrimExpr& other) override;
   bool VisitStmt(const Stmt& n, const Stmt& other) override;
 
+  bool VisitExpr_(const CallNode* op, const PrimExpr& other) override;
   bool VisitStmt_(const ForNode* op, const Stmt& other) override;
   bool VisitStmt_(const SeqStmtNode* op, const Stmt& other) override;
   bool VisitStmt_(const BufferStoreNode* op, const Stmt& other) override;
diff --git a/src/tir/schedule/primitive/blockize_tensorize.cc b/src/tir/schedule/primitive/blockize_tensorize.cc
index e8445a510147..c057a3d4fe72 100644
--- a/src/tir/schedule/primitive/blockize_tensorize.cc
+++ b/src/tir/schedule/primitive/blockize_tensorize.cc
@@ -20,6 +20,7 @@
 
 #include
 
+#include "../../transforms/simplify.h"
 #include "../ir_comparator.h"
 #include "../utils.h"
 
@@ -755,7 +756,9 @@ void Tensorize(ScheduleState self, const StmtSRef& sref, const TensorIntrin& int
         << GetRef(sref->stmt);
     throw;
   }
-  PrimFunc intrin_desc = intrin->desc;
+
+  arith::Analyzer analyzer;
+  PrimFunc intrin_desc = Simplify(intrin->desc, &analyzer);
   PrimFunc intrin_impl = DeepCopy(intrin->impl);
 
   int index_dtype_bits = -1;
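For reference, below is a minimal Python sketch of the kind of folding the analyzer-backed simplifier performs, which is why running Simplify over intrin->desc before matching helps: index arithmetic such as vi // 8 and vi % 8 * 4 collapses once the loop extent is known. The variable name and bound are chosen here for illustration and are not taken from the patch.

    import tvm
    from tvm import tir

    ana = tvm.arith.Analyzer()
    vi = tir.Var("vi", "int32")
    # Bind vi to [0, 8), mirroring the extent of the decode loop in the intrinsic description.
    ana.bind(vi, tvm.ir.Range.from_min_extent(0, 8))
    print(ana.simplify(vi // 8))     # expected to fold to 0
    print(ana.simplify(vi % 8 * 4))  # expected to fold to vi * 4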
diff --git a/src/tir/transforms/simplify.cc b/src/tir/transforms/simplify.cc
index 44d64df63d9f..f518c61bc676 100644
--- a/src/tir/transforms/simplify.cc
+++ b/src/tir/transforms/simplify.cc
@@ -21,6 +21,9 @@
  * \file simplify.cc
  * \brief Statement simplifier based on analyzer
  */
+
+#include "../../tir/transforms/simplify.h"
+
 #include
 #include
 #include
@@ -339,6 +342,11 @@ class StmtSimplifier : public IRMutatorWithAnalyzer {
 }  // namespace arith
 
 namespace tir {
+
+PrimFunc Simplify(PrimFunc func, arith::Analyzer* analyzer) {
+  return arith::StmtSimplifier::Apply(std::move(func), analyzer);
+}
+
 namespace transform {
 
 Pass Simplify() {
diff --git a/src/tir/transforms/simplify.h b/src/tir/transforms/simplify.h
index 43afc5e48dcb..25c9dd5791d9 100644
--- a/src/tir/transforms/simplify.h
+++ b/src/tir/transforms/simplify.h
@@ -25,17 +25,16 @@
 #define TVM_TIR_TRANSFORMS_SIMPLIFY_H_
 
 #include
-#include
+#include
 
 namespace tvm {
 namespace tir {
 
-/* \brief Simplifies the statement
+/* \brief Simplifies the prim func
  *
- * Applies the same behavior as the tir.transform.Simplify pass, but
- * on a single statement, usable as a subroutine in other passes.
+ * Applies the same behavior as the tir.transform.Simplify pass.
  */
-Stmt Simplify(Stmt stmt, arith::Analyzer* analyzer);
+PrimFunc Simplify(PrimFunc stmt, arith::Analyzer* analyzer);
 
 }  // namespace tir
 }  // namespace tvm
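A rough Python-side counterpart of the new PrimFunc-level Simplify entry point is the existing tir.transform.Simplify pass applied to a module wrapping a single PrimFunc. The function and buffer names below are hypothetical and only illustrate the behavior.

    import tvm
    from tvm.script import tir as T

    @T.prim_func
    def before(a: T.Buffer((8,), "int32"), b: T.Buffer((8,), "int32")):
        for i in T.grid(8):
            with T.block("copy"):
                vi = T.axis.remap("S", [i])
                b[vi] = a[vi // 8 * 8 + vi % 8]  # should simplify to a[vi]

    after = tvm.tir.transform.Simplify()(tvm.IRModule.from_expr(before))
    print(after.script())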
diff --git a/tests/python/tir-schedule/test_tir_schedule_tensorize.py b/tests/python/tir-schedule/test_tir_schedule_tensorize.py
index 1891914bc06f..789d6be3ad0b 100644
--- a/tests/python/tir-schedule/test_tir_schedule_tensorize.py
+++ b/tests/python/tir-schedule/test_tir_schedule_tensorize.py
@@ -836,6 +836,124 @@ def tensorized_matmul_int64_shape(
     assert_structural_equal_ignore_global_symbol(s.mod["main"], tensorized_matmul_int64_shape)
     verify_trace_roundtrip(sch=s, mod=matmul_int64_shape)
 
+def _tir_packed_int_to_int_to_float(storage_nbit: int):
+    storage_dtype = "int" + str(storage_nbit)
+
+    def f_convert(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str):
+        assert val.dtype == storage_dtype
+        mask = tir.const((1 << nbit) - 1, "int32")
+        unextended = (val >> (pos.astype("int32") * tir.const(nbit, "int32"))) & mask
+        return tir.Cast(dtype, (unextended << tir.const(32 - nbit, "int32")) >> tir.const(32 - nbit, "int32"))
+
+    return f_convert
+
+@T.prim_func
+def decode_i4s_to_f16_desc(compressed: T.handle, decompressed: T.handle) -> None:
+    Compressed = T.match_buffer(
+        compressed,
+        [
+            1,
+        ],
+        dtype="int32",
+        scope="local",
+    )
+    Decompressed = T.match_buffer(
+        decompressed,
+        [
+            8,
+        ],
+        dtype="float16",
+        scope="local",
+    )
+
+    with T.block("root"):
+        T.reads(Compressed[0:1])
+        T.writes(Decompressed[0:8])
+        for i in T.grid(8):
+            with T.block("decode"):
+                vi = T.axis.remap("S", [i])
+                Decompressed[vi] = _tir_packed_int_to_int_to_float(32)(
+                    4,
+                    Compressed[vi // 8],
+                    vi % 8,
+                    dtype="float16",
+                )
+
+@T.prim_func
+def decode_i4s_to_f16_impl(compressed: T.handle, decompressed: T.handle) -> None:
+    Compressed = T.match_buffer(
+        compressed,
+        [
+            1,
+        ],
+        dtype="int32",
+        scope="local",
+    )
+    Decompressed = T.match_buffer(
+        decompressed,
+        [
+            8,
+        ],
+        dtype="float16",
+        scope="local",
+    )
+
+    with T.block("root"):
+        T.reads(Compressed[0:1])
+        T.writes(Decompressed[0:8])
+        T.call_extern(
+            "handle",
+            "test_decode_i4s_to_f16",
+            Compressed.data,
+            Decompressed.data,
+            8,
+        )
+
+tir.TensorIntrin.register("test_decode_i4s_to_f16_intrin", decode_i4s_to_f16_desc, decode_i4s_to_f16_impl)
+
+def test_tensorize_arith_simplification():
+    # fmt: off
+    @T.prim_func
+    def decode_i4s_to_int32_to_f16():
+        B_decode_local = T.alloc_buffer((16384, 16384), "float16", scope="local")
+        B_local = T.alloc_buffer((16384, 2048), "int32", scope="local")
+        for ax0_0 in T.thread_binding(8192, thread="blockIdx.x"):
+            for ax0_1 in T.thread_binding(2, thread="threadIdx.y"):
+                for ax1_0 in range(32):
+                    for ax1_1 in T.thread_binding(64, thread="threadIdx.x"):
+                        for ax0, ax1 in T.grid(1, 8):
+                            with T.block("B_decode_local"):
+                                v0 = T.axis.spatial(16384, ax0_0 * 2 + ax0_1 + ax0)
+                                v1 = T.axis.spatial(16384, ax1_0 * 512 + ax1_1 * 8 + ax1)
+                                T.reads(B_local[v0, v1 // 8])
+                                T.writes(B_decode_local[v0, v1])
+                                B_decode_local[v0, v1] = T.Cast("float16", T.shift_right(T.shift_left(T.bitwise_and(T.shift_right(B_local[v0, v1 // 8], v1 % 8 * 4), 15), 28), 28))
+
+    @T.prim_func
+    def tensorized_decode_i4s_to_int32_to_f16():
+        B_decode_local = T.alloc_buffer((16384, 16384), "float16", scope="local")
+        B_local = T.alloc_buffer((16384, 2048), "int32", scope="local")
+        for ax0_0 in T.thread_binding(8192, thread="blockIdx.x"):
+            for ax0_1 in T.thread_binding(2, thread="threadIdx.y"):
+                for ax1_0 in range(32):
+                    for ax1_1 in T.thread_binding(64, thread="threadIdx.x"):
+                        for ax0 in range(1):
+                            with T.block("B_decode_local_o"):
+                                v0_o = T.axis.spatial(16384, ax0_0 * 2 + ax0_1 + ax0)
+                                v1_o = T.axis.spatial(2048, ax1_0 * 64 + ax1_1)
+                                T.reads(B_local[v0_o, v1_o])
+                                T.writes(B_decode_local[v0_o, v1_o * 8:v1_o * 8 + 8])
+                                Compressed = T.match_buffer(B_local[v0_o, v1_o], (1,), "int32", scope="local")
+                                Decompressed = T.match_buffer(B_decode_local[v0_o, v1_o * 8:v1_o * 8 + 8], (8,), "float16", scope="local")
+                                T.call_extern("handle", "test_decode_i4s_to_f16", Compressed.data, Decompressed.data, 8)
+
+    s = tir.Schedule(decode_i4s_to_int32_to_f16, debug_mask="all")
+    update = s.get_block("B_decode_local")
+    ii = s.get_loops(update)[-1]
+    s.tensorize(ii, "test_decode_i4s_to_f16_intrin")
+    assert_structural_equal_ignore_global_symbol(s.mod["main"], tensorized_decode_i4s_to_int32_to_f16)
+    verify_trace_roundtrip(sch=s, mod=decode_i4s_to_int32_to_f16)
+
 
 if __name__ == "__main__":
     tvm.testing.main()
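When a tensorize pattern still fails to match, the simplification that Tensorize() now applies to the description can be reproduced from Python as a debugging aid. The sketch below assumes the "test_decode_i4s_to_f16_intrin" intrinsic registered in the test above and prints roughly the description the comparator matches against after this change.

    import tvm
    from tvm import tir

    intrin = tir.TensorIntrin.get("test_decode_i4s_to_f16_intrin")
    simplified = tvm.tir.transform.Simplify()(tvm.IRModule.from_expr(intrin.desc))
    print(simplified.script())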