From 7dfc863df8b6c9227a03547e5a0bf23f44c3f62d Mon Sep 17 00:00:00 2001
From: Venkat Rasagna Komatireddy <89959097+rasagna-quic@users.noreply.github.com>
Date: Thu, 4 Jan 2024 12:44:07 +0530
Subject: [PATCH] [Unity] Alter op impl handling empty transform for output (#16331)

AlterOpImpl: handle an empty (null) layout transform for an output.

UpdateStructInfo now returns the input TensorStructInfo unchanged when the
IndexMap is null, rather than calling MapShape on it. test_reshape covers
this with layout_changes=[index_map, None].
---
 src/relax/transform/alter_op_impl.cc      |   1 +
 .../relax/test_transform_alter_op_impl.py | 100 ++++++++++++++++++
 2 files changed, 101 insertions(+)

diff --git a/src/relax/transform/alter_op_impl.cc b/src/relax/transform/alter_op_impl.cc
index 98d64dd7a80b..8b5518212cc8 100644
--- a/src/relax/transform/alter_op_impl.cc
+++ b/src/relax/transform/alter_op_impl.cc
@@ -324,6 +324,7 @@ class AlterOpImplMutator : public ExprMutator {
 
   /*! \brief Returns the TensorStructInfo after applying the \p transform on its shape */
   StructInfo UpdateStructInfo(const TensorStructInfo& tensor_sinfo, const IndexMap& transform) {
+    if (transform.get() == nullptr) return tensor_sinfo;
     auto shape = GetShapeFromTensorStructInfo(tensor_sinfo);
     arith::Analyzer analyzer;
     auto new_shape = transform->MapShape(shape, &analyzer);
diff --git a/tests/python/relax/test_transform_alter_op_impl.py b/tests/python/relax/test_transform_alter_op_impl.py
index 3cbba9a031b7..f2bad31f2116 100644
--- a/tests/python/relax/test_transform_alter_op_impl.py
+++ b/tests/python/relax/test_transform_alter_op_impl.py
@@ -472,5 +472,105 @@ def add_2d(arg0: T.Buffer((4, 4), "float32"), arg1: T.Buffer((4, 4), "float32"),
     )
 
 
+def test_reshape():
+    @I.ir_module
+    class Before:
+        @T.prim_func(private=True)
+        def reshape(
+            A: T.Buffer((T.int64(850), T.int64(2048)), "float16"),
+            T_reshape: T.Buffer((T.int64(850), T.int64(1), T.int64(2048)), "float16"),
+        ):
+            T.func_attr({"operator_name": "relax.reshape"})
+            for ax0, ax1, ax2 in T.grid(T.int64(850), T.int64(1), T.int64(2048)):
+                with T.block("T_reshape"):
+                    v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
+                    T.reads(
+                        A[
+                            (v_ax2 // T.int64(2048) + v_ax0 + v_ax1) % T.int64(850),
+                            v_ax2 % T.int64(2048),
+                        ]
+                    )
+                    T.writes(T_reshape[v_ax0, v_ax1, v_ax2])
+                    T_reshape[v_ax0, v_ax1, v_ax2] = A[
+                        (v_ax2 // T.int64(2048) + v_ax0 + v_ax1) % T.int64(850),
+                        v_ax2 % T.int64(2048),
+                    ]
+
+        @R.function
+        def main(
+            x: R.Tensor((850, 2048), dtype="float16")
+        ) -> R.Tensor((850, 1, 2048), dtype="float16"):
+            cls = Before
+            with R.dataflow():
+                lv = R.call_tir(
+                    cls.reshape, (x,), out_sinfo=R.Tensor((850, 1, 2048), dtype="float16")
+                )
+                gv: R.Tensor((850, 1, 2048), dtype="float16") = lv
+                R.output(gv)
+            return gv
+
+    @I.ir_module
+    class Expected:
+        @T.prim_func(private=True)
+        def relax_reshape_replacement(
+            A: T.Buffer((T.int64(850), T.int64(2), T.int64(1024)), "float16"),
+            T_reshape: T.Buffer((T.int64(850), T.int64(1), T.int64(2048)), "float16"),
+        ):
+            T.func_attr({"operator_name": "relax.reshape"})
+            for ax0, ax1, ax2 in T.grid(T.int64(850), T.int64(1), T.int64(2048)):
+                with T.block("T_reshape"):
+                    v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
+                    T.reads(A[v_ax0, v_ax2 // T.int64(1024), v_ax2 % T.int64(1024)])
+                    T.writes(T_reshape[v_ax0, v_ax1, v_ax2])
+                    T_reshape[v_ax0, v_ax1, v_ax2] = A[
+                        v_ax0, v_ax2 // T.int64(1024), v_ax2 % T.int64(1024)
+                    ]
+
+        @R.function
+        def main(
+            x: R.Tensor((850, 2048), dtype="float16")
+        ) -> R.Tensor((850, 1, 2048), dtype="float16"):
+            cls = Expected
+            with R.dataflow():
+                lv: R.Tensor((850, 2, 1024), dtype="float16") = R.layout_transform(
+                    x,
+                    index_map=T.index_map(lambda i, j: (i, j // 1024, j % 1024)),
+                    pad_value=None,
+                    axis_separators=[],
+                )
+                lv_1 = R.call_tir(
+                    cls.relax_reshape_replacement,
+                    (lv,),
+                    out_sinfo=R.Tensor((850, 1, 2048), dtype="float16"),
+                )
+                gv: R.Tensor((850, 1, 2048), dtype="float16") = lv_1
+                R.output(gv)
+            return gv
+
+    @T.prim_func(private=True)
+    def reshape_new(
+        A: T.Buffer((T.int64(850), T.int64(2), T.int64(1024)), "float16"),
+        T_reshape: T.Buffer((T.int64(850), T.int64(1), T.int64(2048)), "float16"),
+    ):
+        for ax0, ax1, ax2 in T.grid(T.int64(850), T.int64(1), T.int64(2048)):
+            with T.block("T_reshape"):
+                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
+                T.reads(A[v_ax0, v_ax2 // T.int64(1024), v_ax2 % T.int64(1024)])
+                T.writes(T_reshape[v_ax0, v_ax1, v_ax2])
+                T_reshape[v_ax0, v_ax1, v_ax2] = A[
+                    v_ax0, v_ax2 // T.int64(1024), v_ax2 % T.int64(1024)
+                ]
+
+    # fmt: on
+    index_map = lambda i, j: (i, j // 1024, j % 1024)
+    _check(
+        Before,
+        Expected,
+        operator_name="relax.reshape",
+        replacement_primfunc=reshape_new,
+        layout_changes=[index_map, None],
+    )
+
+
 if __name__ == "__main__":
     tvm.testing.main()