From b5dd002abfb74ca1f007193b26da5095e08304c3 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Tue, 26 Mar 2024 08:41:16 -0500 Subject: [PATCH] [Relax] Unit-test for structural equal of recursive function A follow-up PR to https://github.com/apache/tvm/pull/16756, adding an explicit unit test for `tvm.ir.assert_structural_equal` of two distinct recursive functions. --- tests/python/relax/test_utils.py | 65 ++++++++++++++++++++++++++++++++ 1 file changed, 65 insertions(+) diff --git a/tests/python/relax/test_utils.py b/tests/python/relax/test_utils.py index 9abc53484b7f..41b0e714d1d0 100644 --- a/tests/python/relax/test_utils.py +++ b/tests/python/relax/test_utils.py @@ -206,5 +206,70 @@ def recursive_lambda(i_arg: R.Prim(value="i")) -> R.Prim("int64"): tvm.ir.assert_structural_equal(func_1, func_2) +def test_structural_equal_with_distinct_recursive_lambda_function(): + """A recursive lambda function may be checked for structural equality + + Like `test_structural_equal_with_recursive_lambda_function`, but + comparing between two distinct functions. + """ + + @R.function(private=True) + def func_a(n: R.Prim("int64")): + @R.function + def recursive_lambda(i_arg: R.Prim(value="i")) -> R.Prim("int64"): + i = T.int64() + if R.prim_value(i == 0): + output = R.prim_value(T.int64(0)) + # ^ + # The first mismatch is here ^ + else: + remainder_relax = recursive_lambda(R.prim_value(i - 1)) + remainder_tir = T.int64() + _ = R.match_cast(remainder_relax, R.Prim(value=remainder_tir)) + output = R.prim_value(i + remainder_tir) + return output + + return recursive_lambda(n) + + @R.function(private=True) + def func_b(n: R.Prim("int64")): + @R.function + def recursive_lambda(i_arg: R.Prim(value="i")) -> R.Prim("int64"): + i = T.int64() + if R.prim_value(i == 0): + output = R.prim_value(T.int64(1)) + # ^ + # The first mismatch is here ^ + else: + remainder_relax = recursive_lambda(R.prim_value(i - 1)) + remainder_tir = T.int64() + _ = R.match_cast(remainder_relax, R.Prim(value=remainder_tir)) + output = R.prim_value(i * remainder_tir) + return output + + return recursive_lambda(n) + + # The path to the first mismatch, which should appear within the + # error message. + mismatch_path = [ + "", + "body", + "blocks[0]", + "bindings[0]", + "value", + "body", + "blocks[0]", + "bindings[0]", + "value", + "true_branch", + "body", + "value", + "value", + ] + + with pytest.raises(ValueError, match=re.escape(".".join(mismatch_path))): + tvm.ir.assert_structural_equal(func_a, func_b) + + if __name__ == "__main__": pytest.main([__file__])