diff --git a/tests/python/relay/test_dataflow_pattern.py b/tests/python/relay/test_dataflow_pattern.py index 3950c02c08a4a..1a2d2470e0b7f 100644 --- a/tests/python/relay/test_dataflow_pattern.py +++ b/tests/python/relay/test_dataflow_pattern.py @@ -696,6 +696,28 @@ def test_match_dominator(): assert diamond.match(out) +def test_match_dominator2(): + # Pattern + conv2d_pat = is_op("nn.conv2d")(wildcard(), wildcard()) + eltwise_pat = (wildcard().has_attr({"TOpPattern": K_ELEMWISE}))(None) + broadcast_pat = (wildcard().has_attr({"TOpPattern": K_BROADCAST}))(None) + path_pat = eltwise_pat | broadcast_pat + injective_pat = (wildcard().has_attr({"TOpPattern": K_INJECTIVE}))(wildcard()) + pattern = injective_pat.dominates(conv2d_pat, path_pat) + + # Graph + inp = relay.var("input") + weight = relay.var("weight") + bias = relay.var("bias") + conv2d = relay.op.nn.conv2d(inp, weight) + bias_add = relay.op.nn.bias_add(conv2d, bias) + relu = relay.op.nn.relu(bias_add) + reshape = relay.op.reshape(relu, newshape=[-1, 2, 8]) + + # Check + assert pattern.match(reshape) + + def test_not_match_dominator(): is_conv2d = is_op("nn.conv2d")(wildcard(), wildcard()) is_unary_elemwise = (wildcard().has_attr({"TOpPattern": K_ELEMWISE}))(wildcard())