diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h
index ea3a5dba6bf7..4a6b06f14f94 100644
--- a/include/tvm/relay/transform.h
+++ b/include/tvm/relay/transform.h
@@ -494,6 +494,15 @@ TVM_DLL Pass ManifestLifetimes();
  */
 TVM_DLL Pass PlanDevices(CompilationConfig config);
 
+/*!
+ * \brief This transform flattens atrous convolution, which corresponds to the sequence of
+ * operations "space_to_batch_nd"->"conv2d"->"batch_to_space_nd", and converts it into a single
+ * convolution with modified "dilation" and recalculated "padding" parameters.
+ *
+ * \return The pass.
+ */
+TVM_DLL Pass FlattenAtrousConv();
+
 }  // namespace transform
 
 /*!
diff --git a/src/relay/backend/utils.cc b/src/relay/backend/utils.cc
index 1cc726c59f65..2bddf7556601 100644
--- a/src/relay/backend/utils.cc
+++ b/src/relay/backend/utils.cc
@@ -262,6 +262,8 @@ Array<Pass> GetPassPrefix(bool is_homegeneous, bool is_vm) {
   // Fast math optimizations.
   pass_seqs.push_back(transform::FastMath());
   pass_seqs.push_back(transform::FoldConstant());
+
+  pass_seqs.push_back(transform::FlattenAtrousConv());
   return pass_seqs;
 }
 
diff --git a/tests/python/relay/test_pass_flatten_atrous_conv.py b/tests/python/relay/test_pass_flatten_atrous_conv.py
index f6b3718e40e4..a3d3eb94aeec 100644
--- a/tests/python/relay/test_pass_flatten_atrous_conv.py
+++ b/tests/python/relay/test_pass_flatten_atrous_conv.py
@@ -19,6 +19,7 @@
 import pytest
 import tvm
 from tvm import relay
+from tvm.contrib import graph_executor
 
 
 def compare_expected_fac(expr, expected_expr, args):
@@ -421,6 +422,50 @@ def test_fac_op_btwn_conv_b2s():
     compare_expected_fac(expr, expected_expr, [x_np])
 
 
+def test_fac_relay_build():
+    # Check the default optimization pipeline
+    shape_x = [1, 5, 5, 4]
+    shape_w = [3, 3, 4, 1]
+
+    x_np = np.random.randint(-128, 127, size=shape_x, dtype="int8").astype("float32")
+    w_np = np.random.randint(-128, 127, size=shape_w, dtype="int8").astype("float32")
+
+    weight = relay.const(w_np)
+    data = relay.var("data", shape=shape_x, dtype="float32")
+    op1 = relay.nn.space_to_batch_nd(data, block_shape=[2, 2], paddings=[[2, 3], [2, 3]])
+    op2 = relay.nn.conv2d(
+        op1,
+        weight,
+        padding=[0, 0, 0, 0],
+        groups=4,
+        channels=4,
+        kernel_size=[3, 3],
+        data_layout="NHWC",
+        kernel_layout="HWOI",
+    )
+    expr = relay.nn.batch_to_space_nd(op2, block_shape=[2, 2], crops=[[0, 1], [0, 1]])
+
+    mod_def = tvm.relay.transform.InferType()(tvm.IRModule.from_expr(expr))
+    result_def = (
+        relay.create_executor("vm", mod=mod_def, device=tvm.cpu(), target="llvm")
+        .evaluate()(x_np)
+        .numpy()
+    )
+
+    graph, lib, params = relay.build(mod_def, "llvm", params=None)
+    rt_mod = graph_executor.create(graph, lib, device=tvm.cpu())
+    rt_mod.set_input("data", x_np)
+    rt_mod.set_input(**params)
+    rt_mod.run()
+    result_flat = rt_mod.get_output(0).numpy()
+
+    assert "space_to_batch_nd" not in graph
+    assert "conv2d" in graph
+    assert "batch_to_space_nd" not in graph
+
+    assert np.array_equal(result_def, result_flat)
+
+
 if __name__ == "__main__":
     import sys
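
Note (reviewer sketch, not part of the diff): the snippet below shows how the new pass might be invoked on its own, assuming it is exposed to Python as `relay.transform.FlattenAtrousConv` (only the C++ registration is shown above). It builds the same "space_to_batch_nd"->"conv2d"->"batch_to_space_nd" pattern as the new test; after the pass runs, the module should contain a single conv2d whose dilation comes from block_shape and whose padding is recalculated.

```python
import numpy as np
import tvm
from tvm import relay

# Build the atrous pattern: space_to_batch_nd -> conv2d -> batch_to_space_nd,
# using the same shapes and attributes as test_fac_relay_build.
shape_x = [1, 5, 5, 4]
shape_w = [3, 3, 4, 1]
data = relay.var("data", shape=shape_x, dtype="float32")
weight = relay.const(np.zeros(shape_w, dtype="float32"))

op1 = relay.nn.space_to_batch_nd(data, block_shape=[2, 2], paddings=[[2, 3], [2, 3]])
op2 = relay.nn.conv2d(
    op1,
    weight,
    padding=[0, 0, 0, 0],
    groups=4,
    channels=4,
    kernel_size=[3, 3],
    data_layout="NHWC",
    kernel_layout="HWOI",
)
expr = relay.nn.batch_to_space_nd(op2, block_shape=[2, 2], crops=[[0, 1], [0, 1]])

mod = tvm.IRModule.from_expr(expr)
mod = relay.transform.InferType()(mod)
# Assumed Python binding for the C++ pass registered in this diff.
mod = relay.transform.FlattenAtrousConv()(mod)
# Expect a single conv2d with dilation=[2, 2] (from block_shape) and recalculated padding.
print(mod)
```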