Skip to content

Commit

Permalink
add testcase
Browse files Browse the repository at this point in the history
  • Loading branch information
wanghuibin0 committed May 11, 2024
1 parent b3889af commit 399bcb3
Showing 1 changed file with 22 additions and 0 deletions.
22 changes: 22 additions & 0 deletions tests/python/relay/test_dataflow_pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -696,6 +696,28 @@ def test_match_dominator():
assert diamond.match(out)


def test_match_dominator2():
# Pattern
conv2d_pat = is_op("nn.conv2d")(wildcard(), wildcard())
eltwise_pat = (wildcard().has_attr({"TOpPattern": K_ELEMWISE}))(None)
broadcast_pat = (wildcard().has_attr({"TOpPattern": K_BROADCAST}))(None)
path_pat = (eltwise_pat | broadcast_pat)
injective_pat = (wildcard().has_attr({"TOpPattern": K_INJECTIVE}))(wildcard())
pattern = injective_pat.dominates(conv2d_pat, path_pat)

# Graph
inp = relay.var("input")
weight = relay.var("weight")
bias = relay.var("bias")
conv2d = relay.op.nn.conv2d(inp, weight)
bias_add = relay.op.nn.bias_add(conv2d, bias)
relu = relay.op.nn.relu(bias_add)
reshape = relay.op.reshape(relu, newshape=[-1, 2, 8])

# Check
assert pattern.match(reshape)


def test_not_match_dominator():
is_conv2d = is_op("nn.conv2d")(wildcard(), wildcard())
is_unary_elemwise = (wildcard().has_attr({"TOpPattern": K_ELEMWISE}))(wildcard())
Expand Down

0 comments on commit 399bcb3

Please sign in to comment.