Skip to content

Commit

Permalink
[Relax] Unit-test for structural equal of recursive function
Browse files Browse the repository at this point in the history
A follow-up PR to #16756, adding an
explicit unit test for `tvm.ir.assert_structural_equal` of two
distinct recursive functions.
  • Loading branch information
Lunderberg committed Mar 26, 2024
1 parent bf2d43e commit b5dd002
Showing 1 changed file with 65 additions and 0 deletions.
65 changes: 65 additions & 0 deletions tests/python/relax/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,5 +206,70 @@ def recursive_lambda(i_arg: R.Prim(value="i")) -> R.Prim("int64"):
tvm.ir.assert_structural_equal(func_1, func_2)


def test_structural_equal_with_distinct_recursive_lambda_function():
"""A recursive lambda function may be checked for structural equality
Like `test_structural_equal_with_recursive_lambda_function`, but
comparing between two distinct functions.
"""

@R.function(private=True)
def func_a(n: R.Prim("int64")):
@R.function
def recursive_lambda(i_arg: R.Prim(value="i")) -> R.Prim("int64"):
i = T.int64()
if R.prim_value(i == 0):
output = R.prim_value(T.int64(0))
# ^
# The first mismatch is here ^
else:
remainder_relax = recursive_lambda(R.prim_value(i - 1))
remainder_tir = T.int64()
_ = R.match_cast(remainder_relax, R.Prim(value=remainder_tir))
output = R.prim_value(i + remainder_tir)
return output

return recursive_lambda(n)

@R.function(private=True)
def func_b(n: R.Prim("int64")):
@R.function
def recursive_lambda(i_arg: R.Prim(value="i")) -> R.Prim("int64"):
i = T.int64()
if R.prim_value(i == 0):
output = R.prim_value(T.int64(1))
# ^
# The first mismatch is here ^
else:
remainder_relax = recursive_lambda(R.prim_value(i - 1))
remainder_tir = T.int64()
_ = R.match_cast(remainder_relax, R.Prim(value=remainder_tir))
output = R.prim_value(i * remainder_tir)
return output

return recursive_lambda(n)

# The path to the first mismatch, which should appear within the
# error message.
mismatch_path = [
"<root>",
"body",
"blocks[0]",
"bindings[0]",
"value",
"body",
"blocks[0]",
"bindings[0]",
"value",
"true_branch",
"body",
"value",
"value",
]

with pytest.raises(ValueError, match=re.escape(".".join(mismatch_path))):
tvm.ir.assert_structural_equal(func_a, func_b)


if __name__ == "__main__":
pytest.main([__file__])

0 comments on commit b5dd002

Please sign in to comment.