Skip to content

Commit

Permalink
[Relay] Add a unit test for structural equality (apache#9745)
Browse files Browse the repository at this point in the history
This is CORE-135 from the forums, which suggested structural equality
was deeply broken. But unable to repro. No harm including unit test.

(attempt 3)
  • Loading branch information
mbs-octoml authored and baoxinqi committed Dec 27, 2021
1 parent 9e50b06 commit 84c07fc
Showing 1 changed file with 44 additions and 28 deletions.
72 changes: 44 additions & 28 deletions tests/python/relay/test_ir_structural_equal_hash.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
# under the License.
import numpy as np
import tvm
from tvm import te
from tvm import relay
from tvm.relay.testing import run_opt_pass

Expand Down Expand Up @@ -756,31 +755,48 @@ def get_fn(with_vid):
assert consistent_equal(get_fn(False), get_fn(False))


def test_lets():
shape = (5, 5)

def func1():
sb = relay.ScopeBuilder()
p0 = relay.var("p0", shape=shape)
p1 = relay.var("p1", shape=shape)
a0 = sb.let("a0", relay.add(p0, relay.const(1)))
a1 = sb.let("a1", relay.add(p1, relay.const(1)))
a2 = sb.let("a2", relay.add(a0, a1))
sb.ret(a2)
return relay.Function([p0, p1], sb.get())

def func2():
# Alpha conversion is structurally equal
sb = relay.ScopeBuilder()
p0 = relay.var("p0", shape=shape)
p1 = relay.var("p1", shape=shape)
a1 = sb.let("a1", relay.add(p0, relay.const(1)))
a0 = sb.let("a0", relay.add(p1, relay.const(1)))
a2 = sb.let("a2", relay.add(a1, a0))
sb.ret(a2)
return relay.Function([p0, p1], sb.get())

def func3():
# But changing the order of bindings is not structurally equal
# (even though algebraically equal)
sb = relay.ScopeBuilder()
p0 = relay.var("p0", shape=shape)
p1 = relay.var("p1", shape=shape)
a1 = sb.let("a1", relay.add(p1, relay.const(1)))
a0 = sb.let("a0", relay.add(p0, relay.const(1)))
a2 = sb.let("a2", relay.add(a1, a0))
sb.ret(a2)
return relay.Function([p0, p1], sb.get())

assert tvm.ir.structural_equal(func1(), func2())
assert not tvm.ir.structural_equal(func1(), func3())


if __name__ == "__main__":
test_fn_vid_map()
test_tensor_type_sequal()
test_incomplete_type_sequal()
test_constant_sequal()
test_type_node_sequal()
test_type_node_incompatible_sequal()
test_expr_node_incompatible_sequal()
test_func_type_sequal()
test_tuple_type_sequal()
test_type_relation_sequal()
test_type_call_sequal()
test_constant_sequal()
test_global_var_sequal()
test_tuple_sequal()
test_tuple_get_item_sequal()
test_function_sequal()
test_function_attr()
test_call_sequal()
test_let_sequal()
test_if_sequal()
test_constructor_sequal()
test_match_sequal()
test_op_sequal()
test_var_sequal()
test_graph_equal()
test_hash_unequal()
test_fn_attribute()
import sys
import pytest

sys.exit(pytest.main([__file__] + sys.argv[1:]))

0 comments on commit 84c07fc

Please sign in to comment.