Add FlattenAtrousConv pass into the default optimize pipeline. (apach…
Icemist authored and Sergey Shtin committed May 17, 2022
1 parent d3a9c21 commit e1396ad
Showing 3 changed files with 56 additions and 0 deletions.
9 changes: 9 additions & 0 deletions include/tvm/relay/transform.h
@@ -494,6 +494,15 @@ TVM_DLL Pass ManifestLifetimes();
*/
TVM_DLL Pass PlanDevices(CompilationConfig config);

/*!
 * \brief This transform flattens an atrous convolution, which corresponds to the sequence of
 * operations "space_to_batch_nd"->"conv2d"->"batch_to_space_nd", and converts it into a single
 * convolution with modified "dilation" and recalculated "padding" parameters.
*
* \return The pass.
*/
TVM_DLL Pass FlattenAtrousConv();

} // namespace transform

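For context, here is a minimal sketch of the rewrite this pass performs, invoked standalone from Python. It assumes the relay.transform.FlattenAtrousConv binding that the test file below exercises; the expected dilation/padding values follow from block_shape=[2, 2] and are stated as expectations, not verified output.

import numpy as np
import tvm
from tvm import relay

# Atrous pattern: space_to_batch_nd -> conv2d -> batch_to_space_nd.
data = relay.var("data", shape=[1, 5, 5, 4], dtype="float32")
weight = relay.const(np.ones([3, 3, 4, 1], dtype="float32"))
sbn = relay.nn.space_to_batch_nd(data, block_shape=[2, 2], paddings=[[2, 3], [2, 3]])
conv = relay.nn.conv2d(
    sbn,
    weight,
    padding=[0, 0, 0, 0],
    groups=4,
    channels=4,
    kernel_size=[3, 3],
    data_layout="NHWC",
    kernel_layout="HWOI",
)
expr = relay.nn.batch_to_space_nd(conv, block_shape=[2, 2], crops=[[0, 1], [0, 1]])

mod = relay.transform.InferType()(tvm.IRModule.from_expr(expr))
# After the pass, the module should contain a single conv2d with
# dilation=[2, 2] and recalculated padding instead of the three ops.
mod = relay.transform.FlattenAtrousConv()(mod)
print(mod)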
2 changes: 2 additions & 0 deletions src/relay/backend/utils.cc
@@ -262,6 +262,8 @@ Array<Pass> GetPassPrefix(bool is_homegeneous, bool is_vm) {
// Fast math optimizations.
pass_seqs.push_back(transform::FastMath());
pass_seqs.push_back(transform::FoldConstant());

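// Flatten "space_to_batch_nd"->"conv2d"->"batch_to_space_nd" into a dilated conv2d.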
pass_seqs.push_back(transform::FlattenAtrousConv());
return pass_seqs;
}

45 changes: 45 additions & 0 deletions tests/python/relay/test_pass_flatten_atrous_conv.py
@@ -19,6 +19,7 @@
import pytest
import tvm
from tvm import relay
from tvm.contrib import graph_executor


def compare_expected_fac(expr, expected_expr, args):
@@ -421,6 +422,50 @@ def test_fac_op_btwn_conv_b2s():
    compare_expected_fac(expr, expected_expr, [x_np])


def test_fac_relay_build():
    # Check that FlattenAtrousConv is applied as part of the default optimize pipeline.
    shape_x = [1, 5, 5, 4]
    shape_w = [3, 3, 4, 1]

    x_np = np.random.randint(-128, 127, size=shape_x, dtype="int8").astype("float32")
    w_np = np.random.randint(-128, 127, size=shape_w, dtype="int8").astype("float32")

    # Build the atrous pattern: space_to_batch_nd -> conv2d -> batch_to_space_nd.
    weight = relay.const(w_np)
    data = relay.var("data", shape=shape_x, dtype="float32")
    op1 = relay.nn.space_to_batch_nd(data, block_shape=[2, 2], paddings=[[2, 3], [2, 3]])
    op2 = relay.nn.conv2d(
        op1,
        weight,
        padding=[0, 0, 0, 0],
        groups=4,
        channels=4,
        kernel_size=[3, 3],
        data_layout="NHWC",
        kernel_layout="HWOI",
    )
    expr = relay.nn.batch_to_space_nd(op2, block_shape=[2, 2], crops=[[0, 1], [0, 1]])

    # Reference result computed with the VM executor.
    mod_def = tvm.relay.transform.InferType()(tvm.IRModule.from_expr(expr))
    result_def = (
        relay.create_executor("vm", mod=mod_def, device=tvm.cpu(), target="llvm")
        .evaluate()(x_np)
        .numpy()
    )

    # Build through the default pipeline and run with the graph executor.
    graph, lib, params = relay.build(mod_def, "llvm", params=None)
    rt_mod = graph_executor.create(graph, lib, device=tvm.cpu())
    rt_mod.set_input("data", x_np)
    rt_mod.set_input(**params)
    rt_mod.run()
    result_flat = rt_mod.get_output(0).numpy()

    # The graph JSON should contain only the flattened conv2d.
    assert "space_to_batch_nd" not in graph
    assert "conv2d" in graph
    assert "batch_to_space_nd" not in graph

    assert np.array_equal(result_def, result_flat)


if __name__ == "__main__":
    import sys

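The new test can be run on its own with pytest (file path as shown in this commit):

pytest tests/python/relay/test_pass_flatten_atrous_conv.py::test_fac_relay_build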
