From 682a62c54de36373d4af7c315845d501bb56e3c9 Mon Sep 17 00:00:00 2001 From: Siyuan Feng Date: Sat, 9 Mar 2024 21:12:28 +0800 Subject: [PATCH] [TIR] Support Vector Reinterpret Calls (#16673) This PR adds support for vector reinterpret calls in TIR. --- src/tir/transforms/vectorize_loop.cc | 14 ++++++++- .../test_tir_transform_vectorize.py | 31 +++++++++++++------ 2 files changed, 35 insertions(+), 10 deletions(-) diff --git a/src/tir/transforms/vectorize_loop.cc b/src/tir/transforms/vectorize_loop.cc index fe589bede612..57536422cf64 100644 --- a/src/tir/transforms/vectorize_loop.cc +++ b/src/tir/transforms/vectorize_loop.cc @@ -32,7 +32,6 @@ #include #include -#include #include namespace tvm { @@ -319,6 +318,17 @@ class Vectorizer : public StmtMutator, public ExprFunctordtype.with_lanes(lanes), op->op, {cond, t, f}); } } + // Reinterpret expr + PrimExpr MutateReinterpretExpr_(const CallNode* op) { + ICHECK(op->op.same_as(builtin::reinterpret())); + PrimExpr value = this->VisitExpr(op->args[0]); + if (value.same_as(op->args[0])) { + return GetRef(op); + } else { + int lanes = value.dtype().lanes(); + return Call(op->dtype.with_lanes(lanes), op->op, {value}); + } + } // Call PrimExpr VisitExpr_(const CallNode* op) final { if (op->op.same_as(builtin::if_then_else())) { @@ -337,6 +347,8 @@ class Vectorizer : public StmtMutator, public ExprFunctor mutated_value = MutateArray(value, &lane); Array new_args{op->args[0], op->args[1], op->args[2], mutated_value[0]}; return Call(op->dtype.with_lanes(lane), op->op, new_args); + } else if (op->op.same_as(builtin::reinterpret())) { + return MutateReinterpretExpr_(op); } auto optional_op = op->op.as(); bool vectorizable = optional_op && op_vectorizable_.get(optional_op.value(), false); diff --git a/tests/python/tir-transform/test_tir_transform_vectorize.py b/tests/python/tir-transform/test_tir_transform_vectorize.py index 2448fffe8929..7d0fac242307 100644 --- a/tests/python/tir-transform/test_tir_transform_vectorize.py +++ b/tests/python/tir-transform/test_tir_transform_vectorize.py @@ -15,7 +15,10 @@ # specific language governing permissions and limitations # under the License. import tvm +import tvm.testing from tvm import te +from tvm.script import ir as I +from tvm.script import tir as T def test_vectorize_loop(): @@ -226,13 +229,23 @@ def test_vectorize_dtype_mismatch(): tvm.lower(s, [A], "llvm", simple_mode=True) +def test_vectorize_with_reinterpret(): + @I.ir_module + class Before: + @T.prim_func + def main(A: T.Buffer((16,), "int32"), B: T.Buffer((16,), "float32")): + for i in T.vectorized(0, 16): + B[i] = T.reinterpret("float32", A[i]) + + @I.ir_module + class After: + @T.prim_func + def main(A: T.Buffer((16,), "int32"), B: T.Buffer((16,), "float32")): + B[0:16] = T.reinterpret("float32x16", A[0:16]) + + mod = tvm.tir.transform.VectorizeLoop()(Before) + tvm.ir.assert_structural_equal(mod, After) + + if __name__ == "__main__": - test_vectorize_vector() - test_vectorize_with_if() - test_vectorize_loop() - test_vectorize_if_then_else() - test_vectorize_with_le_cond() - test_vectorize_with_ge_cond() - test_vectorize_let() - test_vectorize_while_fail() - test_vectorize_dtype_mismatch() + tvm.testing.main()