Skip to content

Commit

Permalink
Test for element-wise
Browse files Browse the repository at this point in the history
  • Loading branch information
MasterJH5574 committed Apr 2, 2022
1 parent 51b8ac3 commit 189e8ae
Showing 1 changed file with 24 additions and 0 deletions.
24 changes: 24 additions & 0 deletions tests/python/relax/test_transform_annotate_tir_op_pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,5 +176,29 @@ def add_with_unit_dim_len_broadcast(
assert new_mod["add_with_unit_dim_len_broadcast"].attrs["op_pattern"] == 1


def test_annotate_opkind_add_element_wise_with_unit_shape():
@tvm.script.ir_module
class InputModule:
@T.prim_func
def add_with_unit_dim_len_element_wise(
rxplaceholder_2: T.Buffer[(64, 112, 112), "float32"],
rxplaceholder_3: T.Buffer[(1, 64, 112, 112, 1, 1), "float32"],
T_add_1: T.Buffer[(64, 112, 112), "float32"],
) -> None:
T.func_attr({"global_symbol": "add5", "tir.noalias": True})
for i0, i1, i2 in T.grid(64, 112, 112):
with T.block("T_add"):
ax0, ax1, ax2 = T.axis.remap("SSS", [i0, i1, i2])
T.reads(rxplaceholder_2[ax0, ax1, ax2], rxplaceholder_3[0, ax0, ax1, ax2, 0, 0])
T.writes(T_add_1[ax0, ax1, ax2])
T_add_1[ax0, ax1, ax2] = (
rxplaceholder_2[ax0, ax1, ax2] + rxplaceholder_3[0, ax0, ax1, ax2, 0, 0]
)

mod = InputModule
new_mod = relax.transform.AnnotateTIROpPattern()(mod)
assert new_mod["add_with_unit_dim_len_element_wise"].attrs["op_pattern"] == 0


if __name__ == "__main__":
pytest.main([__file__])

0 comments on commit 189e8ae

Please sign in to comment.