Skip to content

Commit

Permalink
[Unity] Alter op impl handling empty transform for output (#16331)
Browse files Browse the repository at this point in the history
Alterop impl, handling empty transform for output
  • Loading branch information
rasagna-quic authored Jan 4, 2024
1 parent 0cf5f47 commit 7dfc863
Show file tree
Hide file tree
Showing 2 changed files with 101 additions and 0 deletions.
1 change: 1 addition & 0 deletions src/relax/transform/alter_op_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,7 @@ class AlterOpImplMutator : public ExprMutator {

/*! \brief Returns the TensorStructInfo after applying the \p transform on its shape */
StructInfo UpdateStructInfo(const TensorStructInfo& tensor_sinfo, const IndexMap& transform) {
if (transform.get() == nullptr) return tensor_sinfo;
auto shape = GetShapeFromTensorStructInfo(tensor_sinfo);
arith::Analyzer analyzer;
auto new_shape = transform->MapShape(shape, &analyzer);
Expand Down
100 changes: 100 additions & 0 deletions tests/python/relax/test_transform_alter_op_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -472,5 +472,105 @@ def add_2d(arg0: T.Buffer((4, 4), "float32"), arg1: T.Buffer((4, 4), "float32"),
)


def test_reshape():
@I.ir_module
class Before:
@T.prim_func(private=True)
def reshape(
A: T.Buffer((T.int64(850), T.int64(2048)), "float16"),
T_reshape: T.Buffer((T.int64(850), T.int64(1), T.int64(2048)), "float16"),
):
T.func_attr({"operator_name": "relax.reshape"})
for ax0, ax1, ax2 in T.grid(T.int64(850), T.int64(1), T.int64(2048)):
with T.block("T_reshape"):
v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
T.reads(
A[
(v_ax2 // T.int64(2048) + v_ax0 + v_ax1) % T.int64(850),
v_ax2 % T.int64(2048),
]
)
T.writes(T_reshape[v_ax0, v_ax1, v_ax2])
T_reshape[v_ax0, v_ax1, v_ax2] = A[
(v_ax2 // T.int64(2048) + v_ax0 + v_ax1) % T.int64(850),
v_ax2 % T.int64(2048),
]

@R.function
def main(
x: R.Tensor((850, 2048), dtype="float16")
) -> R.Tensor((850, 1, 2048), dtype="float16"):
cls = Before
with R.dataflow():
lv = R.call_tir(
cls.reshape, (x,), out_sinfo=R.Tensor((850, 1, 2048), dtype="float16")
)
gv: R.Tensor((850, 1, 2048), dtype="float16") = lv
R.output(gv)
return gv

@I.ir_module
class Expected:
@T.prim_func(private=True)
def relax_reshape_replacement(
A: T.Buffer((T.int64(850), T.int64(2), T.int64(1024)), "float16"),
T_reshape: T.Buffer((T.int64(850), T.int64(1), T.int64(2048)), "float16"),
):
T.func_attr({"operator_name": "relax.reshape"})
for ax0, ax1, ax2 in T.grid(T.int64(850), T.int64(1), T.int64(2048)):
with T.block("T_reshape"):
v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
T.reads(A[v_ax0, v_ax2 // T.int64(1024), v_ax2 % T.int64(1024)])
T.writes(T_reshape[v_ax0, v_ax1, v_ax2])
T_reshape[v_ax0, v_ax1, v_ax2] = A[
v_ax0, v_ax2 // T.int64(1024), v_ax2 % T.int64(1024)
]

@R.function
def main(
x: R.Tensor((850, 2048), dtype="float16")
) -> R.Tensor((850, 1, 2048), dtype="float16"):
cls = Expected
with R.dataflow():
lv: R.Tensor((850, 2, 1024), dtype="float16") = R.layout_transform(
x,
index_map=T.index_map(lambda i, j: (i, j // 1024, j % 1024)),
pad_value=None,
axis_separators=[],
)
lv_1 = R.call_tir(
cls.relax_reshape_replacement,
(lv,),
out_sinfo=R.Tensor((850, 1, 2048), dtype="float16"),
)
gv: R.Tensor((850, 1, 2048), dtype="float16") = lv_1
R.output(gv)
return gv

@T.prim_func(private=True)
def reshape_new(
A: T.Buffer((T.int64(850), T.int64(2), T.int64(1024)), "float16"),
T_reshape: T.Buffer((T.int64(850), T.int64(1), T.int64(2048)), "float16"),
):
for ax0, ax1, ax2 in T.grid(T.int64(850), T.int64(1), T.int64(2048)):
with T.block("T_reshape"):
v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
T.reads(A[v_ax0, v_ax2 // T.int64(1024), v_ax2 % T.int64(1024)])
T.writes(T_reshape[v_ax0, v_ax1, v_ax2])
T_reshape[v_ax0, v_ax1, v_ax2] = A[
v_ax0, v_ax2 // T.int64(1024), v_ax2 % T.int64(1024)
]

# fmt: on
index_map = lambda i, j: (i, j // 1024, j % 1024)
_check(
Before,
Expected,
operator_name="relax.reshape",
replacement_primfunc=reshape_new,
layout_changes=[index_map, None],
)


if __name__ == "__main__":
tvm.testing.main()

0 comments on commit 7dfc863

Please sign in to comment.