Skip to content

Commit

Permalink
[TIR] Enhance and fix tensorize schedule for some cases (#16560)
Browse files Browse the repository at this point in the history
* support tensorize with simplified expressions and call expressions

* replace stmt simplifier with primfunc simplifier

* lint fix

* lint:remove white space

* lint: remove white space

* cpp lint fix

* lint: resolve include

* clang format lint fix
  • Loading branch information
LeiWang1999 authored Mar 8, 2024
1 parent 657880c commit 7b7677f
Show file tree
Hide file tree
Showing 6 changed files with 159 additions and 6 deletions.
24 changes: 24 additions & 0 deletions src/tir/schedule/ir_comparator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,30 @@ bool TensorizeComparator::VisitExpr(const PrimExpr& n, const PrimExpr& other) {
return equal;
}

/*!
 * \brief Compare a call expression against `other` for tensorize matching.
 *
 * Two calls match when they refer to the same op, carry the same dtype code,
 * and their argument arrays match element-wise under
 * TensorizeComparator::VisitExpr. In assert mode every kind of mismatch is
 * reported through EmitError before returning false.
 *
 * \param op The call node on the left-hand side of the comparison.
 * \param other The expression to compare against.
 * \return true if the two calls match, false otherwise.
 */
bool TensorizeComparator::VisitExpr_(const CallNode* op, const PrimExpr& other) {
  const auto* rhs = other.as<CallNode>();
  // Defensive guard: `as<CallNode>()` returns nullptr when `other` is not a
  // CallNode, and every access below dereferences `rhs`.
  if (rhs == nullptr) {
    if (assert_mode_) {
      std::ostringstream os;
      os << "CallNode expected on the right-hand side, but got: " << other;
      EmitError(os.str());
    }
    return false;
  }
  if (!rhs->op.same_as(op->op)) {
    // Report the op mismatch as well, so assert mode explains every failure path.
    if (assert_mode_) {
      std::ostringstream os;
      os << "CallNode ops do not match: op->op=" << op->op << " vs rhs->op=" << rhs->op;
      EmitError(os.str());
    }
    return false;
  }
  if (op->dtype.code() != rhs->dtype.code()) {
    if (assert_mode_) {
      std::ostringstream os;
      os << "CallNode data type codes do not match: op->dtype.code()=" << op->dtype.code()
         << " vs rhs->dtype.code()=" << rhs->dtype.code();
      EmitError(os.str());
    }
    return false;
  }
  if (!CompareArray(op->args, rhs->args, &TensorizeComparator::VisitExpr)) {
    if (assert_mode_) {
      std::ostringstream os;
      // The compared field is `args`; the message previously said "iter_values".
      os << "CallNode args do not match: op->args=" << op->args
         << " vs rhs->args=" << rhs->args;
      EmitError(os.str());
    }
    return false;
  }
  return true;
}

bool TensorizeComparator::VisitStmt_(const ForNode* op, const Stmt& other) {
const auto* rhs = other.as<ForNode>();
if (!DefEqual(op->loop_var, rhs->loop_var)) {
Expand Down
1 change: 1 addition & 0 deletions src/tir/schedule/ir_comparator.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ class TensorizeComparator : public ExprComparator, public StmtComparator {
bool VisitExpr(const PrimExpr& n, const PrimExpr& other) override;
bool VisitStmt(const Stmt& n, const Stmt& other) override;

bool VisitExpr_(const CallNode* op, const PrimExpr& other) override;
bool VisitStmt_(const ForNode* op, const Stmt& other) override;
bool VisitStmt_(const SeqStmtNode* op, const Stmt& other) override;
bool VisitStmt_(const BufferStoreNode* op, const Stmt& other) override;
Expand Down
5 changes: 4 additions & 1 deletion src/tir/schedule/primitive/blockize_tensorize.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

#include <functional>

#include "../../transforms/simplify.h"
#include "../ir_comparator.h"
#include "../utils.h"

Expand Down Expand Up @@ -755,7 +756,9 @@ void Tensorize(ScheduleState self, const StmtSRef& sref, const TensorIntrin& int
<< GetRef<Stmt>(sref->stmt);
throw;
}
PrimFunc intrin_desc = intrin->desc;

arith::Analyzer analyzer;
PrimFunc intrin_desc = Simplify(intrin->desc, &analyzer);
PrimFunc intrin_impl = DeepCopy(intrin->impl);

int index_dtype_bits = -1;
Expand Down
8 changes: 8 additions & 0 deletions src/tir/transforms/simplify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@
* \file simplify.cc
* \brief Statement simplifier based on analyzer
*/

#include "../../tir/transforms/simplify.h"

#include <tvm/arith/analyzer.h>
#include <tvm/runtime/registry.h>
#include <tvm/tir/analysis.h>
Expand Down Expand Up @@ -339,6 +342,11 @@ class StmtSimplifier : public IRMutatorWithAnalyzer {
} // namespace arith

namespace tir {

/*!
 * \brief Simplify a PrimFunc by delegating to arith::StmtSimplifier::Apply.
 * \param func The function to simplify; it is moved into the simplifier.
 * \param analyzer The analyzer forwarded to the simplifier.
 * \return The function returned by the simplifier.
 */
PrimFunc Simplify(PrimFunc func, arith::Analyzer* analyzer) {
  PrimFunc simplified = arith::StmtSimplifier::Apply(std::move(func), analyzer);
  return simplified;
}

namespace transform {

Pass Simplify() {
Expand Down
9 changes: 4 additions & 5 deletions src/tir/transforms/simplify.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,17 +25,16 @@
#define TVM_TIR_TRANSFORMS_SIMPLIFY_H_

#include <tvm/arith/analyzer.h>
#include <tvm/tir/stmt.h>
#include <tvm/tir/function.h>

namespace tvm {
namespace tir {

/* \brief Simplifies the statement
/* \brief Simplifies the prim func
*
* Applies the same behavior as the tir.transform.Simplify pass, but
* on a single statement, usable as a subroutine in other passes.
* Applies the same behavior as the tir.transform.Simplify pass.
*/
Stmt Simplify(Stmt stmt, arith::Analyzer* analyzer);
/*!
 * \brief Simplifies a PrimFunc.
 * \param func The function to simplify.
 * \param analyzer The analyzer handed to the simplifier.
 * \return The simplified function.
 */
PrimFunc Simplify(PrimFunc func, arith::Analyzer* analyzer);

} // namespace tir
} // namespace tvm
Expand Down
118 changes: 118 additions & 0 deletions tests/python/tir-schedule/test_tir_schedule_tensorize.py
Original file line number Diff line number Diff line change
Expand Up @@ -836,6 +836,124 @@ def tensorized_matmul_int64_shape(
assert_structural_equal_ignore_global_symbol(s.mod["main"], tensorized_matmul_int64_shape)
verify_trace_roundtrip(sch=s, mod=matmul_int64_shape)

def _tir_packed_int_to_int_to_float(storage_nbit: int):
    """Return a converter that unpacks one fixed-width field and casts it.

    The returned ``f_convert(nbit, val, pos, dtype)`` asserts that ``val`` has
    dtype ``"int<storage_nbit>"``, selects the ``nbit``-wide field at index
    ``pos`` by shifting right and masking, shifts the field left and then right
    by ``32 - nbit`` bits, and wraps the result in ``tir.Cast(dtype, ...)``.
    """
    storage_dtype = f"int{storage_nbit}"

    def f_convert(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str):
        assert val.dtype == storage_dtype
        field_mask = tir.const((1 << nbit) - 1, "int32")
        # Bit offset of the requested field inside the storage word.
        bit_offset = pos.astype("int32") * tir.const(nbit, "int32")
        unextended = (val >> bit_offset) & field_mask
        # Move the field to the top of a 32-bit word, then back down.
        pad_bits = 32 - nbit
        widened = unextended << tir.const(pad_bits, "int32")
        return tir.Cast(dtype, widened >> tir.const(pad_bits, "int32"))

    return f_convert

# Description side of the "test_decode_i4s_to_f16_intrin" tensor intrinsic:
# for each of the 8 output elements, convert one 4-bit field of a single int32
# element into a float16 value via _tir_packed_int_to_int_to_float.
@T.prim_func
def decode_i4s_to_f16_desc(compressed: T.handle, decompressed: T.handle) -> None:
# Input: one int32 element in "local" scope.
Compressed = T.match_buffer(
compressed,
[
1,
],
dtype="int32",
scope="local",
)
# Output: eight float16 elements in "local" scope.
Decompressed = T.match_buffer(
decompressed,
[
8,
],
dtype="float16",
scope="local",
)

with T.block("root"):
T.reads(Compressed[0:1])
T.writes(Decompressed[0:8])
for i in T.grid(8):
with T.block("decode"):
vi = T.axis.remap("S", [i])
# Helper arguments: field width 4, the storage element holding the
# field (vi // 8), and the field index inside it (vi % 8).
Decompressed[vi] = _tir_packed_int_to_int_to_float(32)(
4,
Compressed[vi // 8],
vi % 8,
dtype="float16",
)

# Implementation side of the "test_decode_i4s_to_f16_intrin" tensor intrinsic:
# same buffer signature as the description, with the body replaced by a single
# extern call.
@T.prim_func
def decode_i4s_to_f16_impl(compressed: T.handle, decompressed: T.handle) -> None:
# Input: one int32 element in "local" scope.
Compressed = T.match_buffer(
compressed,
[
1,
],
dtype="int32",
scope="local",
)
# Output: eight float16 elements in "local" scope.
Decompressed = T.match_buffer(
decompressed,
[
8,
],
dtype="float16",
scope="local",
)

with T.block("root"):
T.reads(Compressed[0:1])
T.writes(Decompressed[0:8])
# Passes both buffers' data pointers to the extern symbol.
# NOTE(review): the trailing 8 is presumably the element count; the
# extern symbol is not defined in this file, so confirm.
T.call_extern(
"handle",
"test_decode_i4s_to_f16",
Compressed.data,
Decompressed.data,
8,
)

# Register the description/implementation pair under the intrinsic name that
# the tensorize call looks up.
tir.TensorIntrin.register(
    "test_decode_i4s_to_f16_intrin",
    decode_i4s_to_f16_desc,
    decode_i4s_to_f16_impl,
)

# Tensorizes the innermost loop of the "B_decode_local" block with the
# "test_decode_i4s_to_f16_intrin" intrinsic and checks the result against a
# hand-written expected function.
# NOTE(review): per the test name this exercises matching that depends on
# arithmetic simplification of the intrinsic description - confirm which
# sub-expressions differ before simplification.
def test_tensorize_arith_simplification():
# fmt: off
# Workload: decode each int32 element of B_local into eight float16 values.
@T.prim_func
def decode_i4s_to_int32_to_f16():
B_decode_local = T.alloc_buffer((16384, 16384), "float16", scope="local")
B_local = T.alloc_buffer((16384, 2048), "int32", scope="local")
for ax0_0 in T.thread_binding(8192, thread="blockIdx.x"):
for ax0_1 in T.thread_binding(2, thread="threadIdx.y"):
for ax1_0 in range(32):
for ax1_1 in T.thread_binding(64, thread="threadIdx.x"):
for ax0, ax1 in T.grid(1, 8):
with T.block("B_decode_local"):
v0 = T.axis.spatial(16384, ax0_0 * 2 + ax0_1 + ax0)
v1 = T.axis.spatial(16384, ax1_0 * 512 + ax1_1 * 8 + ax1)
T.reads(B_local[v0, v1 // 8])
T.writes(B_decode_local[v0, v1])
B_decode_local[v0, v1] = T.Cast("float16", T.shift_right(T.shift_left(T.bitwise_and(T.shift_right(B_local[v0, v1 // 8], v1 % 8 * 4), 15), 28), 28))

# Expected result: the 8-wide inner loop is replaced by an outer block that
# binds buffer regions with match_buffer and calls the extern symbol.
@T.prim_func
def tensorized_decode_i4s_to_int32_to_f16():
B_decode_local = T.alloc_buffer((16384, 16384), "float16", scope="local")
B_local = T.alloc_buffer((16384, 2048), "int32", scope="local")
for ax0_0 in T.thread_binding(8192, thread="blockIdx.x"):
for ax0_1 in T.thread_binding(2, thread="threadIdx.y"):
for ax1_0 in range(32):
for ax1_1 in T.thread_binding(64, thread="threadIdx.x"):
for ax0 in range(1):
with T.block("B_decode_local_o"):
v0_o = T.axis.spatial(16384, ax0_0 * 2 + ax0_1 + ax0)
v1_o = T.axis.spatial(2048, ax1_0 * 64 + ax1_1)
T.reads(B_local[v0_o, v1_o])
T.writes(B_decode_local[v0_o, v1_o * 8:v1_o * 8 + 8])
Compressed = T.match_buffer(B_local[v0_o, v1_o], (1,), "int32", scope="local")
Decompressed = T.match_buffer(B_decode_local[v0_o, v1_o * 8:v1_o * 8 + 8], (8,), "float16", scope="local")
T.call_extern("handle", "test_decode_i4s_to_f16", Compressed.data, Decompressed.data, 8)

s = tir.Schedule(decode_i4s_to_int32_to_f16, debug_mask="all")
update = s.get_block("B_decode_local")
# [-1] selects the last loop returned for the block; that loop is tensorized.
ii = s.get_loops(update)[-1]
s.tensorize(ii, "test_decode_i4s_to_f16_intrin")
assert_structural_equal_ignore_global_symbol(s.mod["main"], tensorized_decode_i4s_to_int32_to_f16)
# Also check the recorded schedule trace via verify_trace_roundtrip.
verify_trace_roundtrip(sch=s, mod=decode_i4s_to_int32_to_f16)


# When this file is executed as a script, hand control to tvm.testing.main().
if __name__ == "__main__":
tvm.testing.main()

0 comments on commit 7b7677f

Please sign in to comment.