diff --git a/python/tvm/relay/transform/transform.py b/python/tvm/relay/transform/transform.py
index e4ee14b62941..566d0ffa2bfa 100644
--- a/python/tvm/relay/transform/transform.py
+++ b/python/tvm/relay/transform/transform.py
@@ -1311,6 +1311,33 @@ def FakeQuantizationToInteger(hard_fail=False, use_qat=False):
     return _ffi_api.FakeQuantizationToInteger(hard_fail, use_qat)
 
 
+def FlattenAtrousConv():
+    # pylint: disable=anomalous-backslash-in-string
+    """
+    Flatten atrous convolutions. This pass finds each sequence of
+    space_to_batch_nd-conv2d-batch_to_space_nd operations:
+
+    .. code-block:: text
+
+      x     w
+      |     |
+      s2b   |
+       \\   /
+        conv2d
+         |
+        b2s
+
+    and converts it into a subgraph with a single convolution that has the modified
+    "dilation" and recalculated "padding" parameters.
+
+    Returns
+    -------
+    ret : tvm.transform.Pass
+        The registered FlattenAtrousConv pass.
+    """
+    return _ffi_api.FlattenAtrousConv()
+
+
 def ToMixedPrecision(mixed_precision_type="float16", missing_op_mode=1):
     """
     Automatic mixed precision rewriter.
Rewrite an FP32 relay graph into a version
diff --git a/src/relay/qnn/utils.h b/src/relay/qnn/utils.h
index b4841c8ddda8..18c592f2ed69 100644
--- a/src/relay/qnn/utils.h
+++ b/src/relay/qnn/utils.h
@@ -270,6 +270,12 @@ static inline std::vector GetFloatVectorFromConstant(const Expr& expr) {
   return vals;
 }
 
+Expr MakeQnnConv2D(Expr data, Expr weight, Expr input_zero_point, Expr kernel_zero_point,
+                   Expr input_scale, Expr kernel_scale, Array strides,
+                   Array padding, Array dilation, int groups,
+                   IndexExpr channels, Array kernel_size, String data_layout,
+                   String kernel_layout, String out_layout, DataType out_dtype);
+
 }  // namespace qnn
 }  // namespace relay
 }  // namespace tvm
diff --git a/src/relay/transforms/flatten_atrous_conv.cc b/src/relay/transforms/flatten_atrous_conv.cc
new file mode 100644
index 000000000000..54e0f193cf8b
--- /dev/null
+++ b/src/relay/transforms/flatten_atrous_conv.cc
@@ -0,0 +1,213 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/transforms/flatten_atrous_conv.cc
+ * \brief This transform flattens atrous convolution, which corresponds to the sequence of
+ * operations: "space_to_batch_nd"->"conv2d"->"batch_to_space_nd".
+ */
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include
+#include
+#include
+
+#include "../qnn/utils.h"
+#include "pattern_utils.h"
+
+namespace tvm {
+namespace relay {
+
+/* Description of FlattenAtrousConv
+ *
+ * The purpose of this pass is to find a sequence of space_to_batch_nd-conv2d-batch_to_space_nd
+ * operations:
+ *
+ *   x     w
+ *   |     |
+ *   s2b   |
+ *    \   /
+ *     conv2d
+ *      |
+ *     b2s
+ *
+ * and convert them into subgraphs with a convolution with the modified "dilation" and
+ * recalculated "padding" parameters.
+ */
+
+using ExprSet = std::unordered_set;
+
+/*! \brief Rewrites one s2b-conv2d-b2s chain into a single dilated convolution. */
+class FlattenAtrousConvSubgraphMutator {
+ public:
+  /*!
+   * \brief Replace the chain that ends at a batch_to_space_nd call with one convolution.
+   * \param expr The batch_to_space_nd call that terminates the chain.
+   * \return The new (qnn.)conv2d call, or expr itself when the chain cannot be rewritten.
+   */
+  Expr MutateSubgraph(const Expr& expr) {
+    try {
+      // Check every node before it is dereferenced, so that a chain without the expected
+      // structure fails an ICHECK (handled by the catch block below) instead of reading null.
+      const CallNode* b2s_node_ = expr.as();
+      ICHECK(b2s_node_ != nullptr);
+      const auto* b2s_attrs = b2s_node_->attrs.as();
+      ICHECK(b2s_attrs != nullptr);
+
+      Array dilation = {b2s_attrs->block_shape[0], b2s_attrs->block_shape[1]};
+
+      const CallNode* conv2d_node_ = b2s_node_->args[0].as();
+      ICHECK(conv2d_node_ != nullptr);
+      const auto* conv2d_attrs = conv2d_node_->attrs.as();
+      ICHECK(conv2d_attrs != nullptr);
+
+      Array kernel_shape = conv2d_attrs->kernel_size;
+      PrimExpr kernel_h = kernel_shape[0];
+      PrimExpr kernel_w = kernel_shape[1];
+
+      const CallNode* s2b_node_ = conv2d_node_->args[0].as();
+      ICHECK(s2b_node_ != nullptr);
+      const auto* s2b_attrs = s2b_node_->attrs.as();
+      ICHECK(s2b_attrs != nullptr);
+
+      Expr data = s2b_node_->args[0];
+      ICHECK(conv2d_attrs->data_layout == "NHWC");
+      Array data_shape = transform::InferTypeLocal(data).as()->shape;
+      PrimExpr in_h = data_shape[1];
+      PrimExpr in_w = data_shape[2];
+
+      PrimExpr dilation_h = dilation[0];
+      PrimExpr dilation_w = dilation[1];
+
+      PrimExpr dilated_kernel_h = (kernel_h - 1) * dilation_h + 1;
+      PrimExpr dilated_kernel_w = (kernel_w - 1) * dilation_w + 1;
+
+      Array strides = {1, 1};
+      PrimExpr stride_h = strides[0];
+      PrimExpr stride_w = strides[1];
+
+      // Padding for one spatial axis: the amount needed for ceil(input / stride) outputs,
+      // split in two halves with the odd element, if any, placed after.
+      auto _get_pad_pair = [](PrimExpr input1d, PrimExpr kernel1d,
+                              PrimExpr stride1d) -> Array {
+        PrimExpr out1d = truncdiv((input1d + stride1d - 1), stride1d);
+        PrimExpr pad = topi::maximum(((out1d - 1) * stride1d + kernel1d - input1d), 0);
+        PrimExpr pad_before = truncdiv(pad, 2);
+        PrimExpr pad_after = pad - pad_before;
+        return {pad_before, pad_after};
+      };
+
+      Array pad_v = _get_pad_pair(in_h, dilated_kernel_h, stride_h);
+      Array pad_h = _get_pad_pair(in_w, dilated_kernel_w, stride_w);
+
+      Array padding = {pad_v[0], pad_h[0], pad_v[1], pad_h[1]};
+
+      Expr weight = conv2d_node_->args[1];
+
+      if (conv2d_node_->op == Op::Get("nn.conv2d")) {
+        return Conv2D(data, weight, strides, padding, dilation, conv2d_attrs->groups,
+                      conv2d_attrs->channels, conv2d_attrs->kernel_size, conv2d_attrs->data_layout,
+                      conv2d_attrs->kernel_layout, conv2d_attrs->out_layout,
+                      conv2d_attrs->out_dtype);
+      }
+
+      if (conv2d_node_->op == Op::Get("qnn.conv2d")) {
+        Expr input_zero_point = conv2d_node_->args[2];
+        Expr kernel_zero_point = conv2d_node_->args[3];
+        Expr input_scale = conv2d_node_->args[4];
+        Expr kernel_scale = conv2d_node_->args[5];
+        return qnn::MakeQnnConv2D(data, weight, input_zero_point, kernel_zero_point, input_scale,
+                                  kernel_scale, strides, padding, dilation, conv2d_attrs->groups,
+                                  conv2d_attrs->channels, conv2d_attrs->kernel_size,
+                                  conv2d_attrs->data_layout, conv2d_attrs->kernel_layout,
+                                  conv2d_attrs->out_layout, conv2d_attrs->out_dtype);
+      }
+
+      DLOG(INFO) << "Ran into an unhandled convolution, skipping " << expr << std::endl;
+      return expr;
+    } catch (std::exception& e) {
+      DLOG(INFO) << "Ran into an error rewriting a subgraph, skipping " << expr << " with "
+                 << e.what() << std::endl;
+      return expr;
+    }
+  }
+};
+
+/*!
+ * \brief Scans call nodes for consecutively visited space_to_batch_nd, (qnn.)conv2d and
+ * batch_to_space_nd calls and rewrites the subgraph that ends at the batch_to_space_nd call.
+ */
+class FlattenAtrousConvRewriter : public MixedModeMutator {
+ protected:
+  Expr Rewrite_(const CallNode* pre, const Expr& post) override {
+    if (const CallNode* call_node = post.as()) {
+      if (ops_[op_iter_].count(call_node->op)) {
+        ++op_iter_;
+        if (op_iter_ == ops_.size()) {
+          op_iter_ = 0;
+          return FlattenAtrousConvSubgraphMutator().MutateSubgraph(post);
+        }
+      } else {
+        // Restart matching; the current call may itself open a new chain.
+        op_iter_ = ops_[0].count(call_node->op) ? 1 : 0;
+      }
+    }
+    return post;
+  }
+
+ private:
+  // Index into ops_ of the op set that the next visited call has to belong to.
+  size_t op_iter_ = 0;
+  // Op sets of the chain, in the order in which their calls are expected to be visited.
+  const std::array ops_ = {
+      ExprSet{Op::Get("nn.space_to_batch_nd")},
+      ExprSet{Op::Get("nn.conv2d"), Op::Get("qnn.conv2d")},
+      ExprSet{Op::Get("nn.batch_to_space_nd")},
+  };
+};
+
+/*! \brief Flatten every matched atrous convolution in expr; mod is not used. */
+Expr FlattenAtrousConv(const Expr& expr, const IRModule& mod) {
+  return FlattenAtrousConvRewriter().Mutate(expr);
+}
+
+namespace transform {
+
+/*! \brief Create the function pass that applies FlattenAtrousConv to each function. */
+Pass FlattenAtrousConv() {
+  runtime::TypedPackedFunc pass_func =
+      [=](Function f, IRModule m, PassContext pc) {
+        return Downcast(FlattenAtrousConv(f, m));
+      };
+  return CreateFunctionPass(pass_func, 0, "FlattenAtrousConv", {"InferType"});
+}
+
+TVM_REGISTER_GLOBAL("relay._transform.FlattenAtrousConv").set_body_typed(FlattenAtrousConv);
+
+}  // namespace transform
+
+}  // namespace relay
+}  // namespace tvm
diff --git a/tests/python/relay/test_pass_flatten_atrous_conv.py b/tests/python/relay/test_pass_flatten_atrous_conv.py
new file mode 100644
index 000000000000..f6b3718e40e4
--- /dev/null
+++ b/tests/python/relay/test_pass_flatten_atrous_conv.py
@@ -0,0 +1,427 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Tests for the relay FlattenAtrousConv pass."""
+import numpy as np
+import pytest
+import tvm
+from tvm import relay
+
+
+def compare_expected_fac(expr, expected_expr, args):
+    """Apply FlattenAtrousConv to ``expr`` and compare the result with ``expected_expr``.
+
+    The rewritten module must match the expected one structurally and, unless the same object
+    is passed for both (negative cases), differ from the input; outputs on ``args`` must agree.
+
+    Parameters
+    ----------
+    expr : relay.Expr
+        Expression handed to the pass.
+    expected_expr : relay.Expr
+        Expression the pass is expected to produce.
+    args : list of numpy.ndarray
+        Inputs used to execute the modules.
+    """
+    mod_def = tvm.relay.transform.InferType()(tvm.IRModule.from_expr(expr))
+    mod_flat = tvm.relay.transform.FlattenAtrousConv()(mod_def)
+    mod_exp = tvm.relay.transform.InferType()(tvm.IRModule.from_expr(expected_expr))
+
+    assert expr is expected_expr or not tvm.ir.structural_equal(mod_def, mod_flat)
+    assert tvm.ir.structural_equal(mod_flat, mod_exp)
+
+    def run(mod):
+        executor = relay.create_executor("vm", mod=mod, device=tvm.cpu(), target="llvm")
+        return executor.evaluate()(*args).numpy()
+
+    result_def = run(mod_def)
+    result_flat = run(mod_flat)
+    result_exp = run(mod_exp)
+
+    assert np.array_equal(result_def, result_flat)
+    assert np.array_equal(result_flat, result_exp)
+
+
+def test_fac_block_shape_2():
+    """Pattern entry with block_shape=[2, 2]."""
+    shape_x = [1, 5, 5, 4]
+    shape_w = [3, 3, 4, 1]
+
+    x_np = np.random.randint(-128, 127, size=shape_x, dtype="int8").astype("float32")
+    w_np = np.random.randint(-128, 127, size=shape_w, dtype="int8").astype("float32")
+
+    weight = relay.const(w_np)
+    data = relay.var("data", shape=shape_x, dtype="float32")
+    op1 = relay.nn.space_to_batch_nd(data, block_shape=[2, 2], paddings=[[2, 3], [2, 3]])
+    op2 = relay.nn.conv2d(
+        op1,
+        weight,
+        padding=[0, 0, 0, 0],
+        groups=4,
+        channels=4,
+        kernel_size=[3, 3],
+        data_layout="NHWC",
+        kernel_layout="HWOI",
+    )
+    expr = relay.nn.batch_to_space_nd(op2, block_shape=[2, 2], crops=[[0, 1], [0, 1]])
+
+    expected_expr = relay.nn.conv2d(
+        data,
+        weight,
+        padding=[2, 2, 2, 2],
+        dilation=[2, 2],
+        groups=4,
+        channels=4,
+        kernel_size=[3, 3],
+        data_layout="NHWC",
+        kernel_layout="HWOI",
+    )
+
+    compare_expected_fac(expr, expected_expr, [x_np])
+
+
+def test_fac_block_shape_4():
+    """Pattern entry with block_shape=[4, 4]."""
+    shape_x = [1, 5, 5, 4]
+    shape_w = [3, 3, 4, 1]
+
+    x_np = np.random.randint(-128, 127, size=shape_x, dtype="int8").astype("float32")
+    w_np = np.random.randint(-128, 127, size=shape_w, dtype="int8").astype("float32")
+
+    weight = relay.const(w_np)
+    data = relay.var("data", shape=shape_x, dtype="float32")
+    op1 = relay.nn.space_to_batch_nd(data, block_shape=[4, 4], paddings=[[4, 7], [4, 7]])
+    op2 = relay.nn.conv2d(
+        op1,
+        weight,
+        padding=[0, 0, 0, 0],
+        groups=4,
+        channels=4,
+        kernel_size=[3, 3],
+        data_layout="NHWC",
+        kernel_layout="HWOI",
+    )
+    expr = relay.nn.batch_to_space_nd(op2, block_shape=[4, 4], crops=[[0, 3], [0, 3]])
+
+    expected_expr = relay.nn.conv2d(
+        data,
+        weight,
+        padding=[4, 4, 4, 4],
+        dilation=[4, 4],
+        groups=4,
+        channels=4,
+        kernel_size=[3, 3],
+        data_layout="NHWC",
+        kernel_layout="HWOI",
+    )
+
+    compare_expected_fac(expr, expected_expr, [x_np])
+
+
+def test_fac_quantize():
+    """Quantized (qnn.conv2d) pattern entry."""
+    shape_x = [1, 5, 5, 4]
+    shape_w = [3, 3, 4, 1]
+
+    x_np = np.random.randint(-128, 127, size=shape_x, dtype="int8")
+    w_np = np.random.randint(-128, 127, size=shape_w, dtype="int8")
+
+    weight = relay.const(w_np)
+    data = relay.var("data", shape=shape_x, dtype="int8")
+    op1 = relay.nn.space_to_batch_nd(data, block_shape=[2, 2], paddings=[[2, 3], [2, 3]])
+    op2 = relay.qnn.op.conv2d(
+        op1,
+        weight,
+        input_zero_point=relay.const(0),
+        kernel_zero_point=relay.const(0),
+        input_scale=relay.const(2.0),
+        kernel_scale=relay.const(1.0),
+        padding=[0, 0, 0, 0],
+        groups=4,
+        channels=4,
+        kernel_size=[3, 3],
+        data_layout="NHWC",
+        kernel_layout="HWOI",
+    )
+    expr = relay.nn.batch_to_space_nd(op2, block_shape=[2, 2], crops=[[0, 1], [0, 1]])
+
+    expected_expr = relay.qnn.op.conv2d(
+        data,
+        weight,
+        input_zero_point=relay.const(0),
+        kernel_zero_point=relay.const(0),
+        input_scale=relay.const(2.0),
+        kernel_scale=relay.const(1.0),
+        padding=[2, 2, 2, 2],
+        dilation=[2, 2],
+        groups=4,
+        channels=4,
+        kernel_size=[3, 3],
+        data_layout="NHWC",
+        kernel_layout="HWOI",
+    )
+
+    compare_expected_fac(expr, expected_expr, [x_np])
+
+
+def test_fac_surrounding():
+    """Pattern entry surrounded by add operations."""
+    shape_x = [1, 5, 5, 4]
+    shape_w = [3, 3, 4, 1]
+
+    x_np = np.random.randint(-128, 127, size=shape_x, dtype="int8").astype("float32")
+    w_np = np.random.randint(-128, 127, size=shape_w, dtype="int8").astype("float32")
+
+    weight = relay.const(w_np)
+    data = relay.var("data", shape=shape_x, dtype="float32")
+    op0 = relay.op.add(data, relay.const(1.0))
+    op1 = relay.nn.space_to_batch_nd(op0, block_shape=[2, 2], paddings=[[2, 3], [2, 3]])
+    op2 = relay.nn.conv2d(
+        op1,
+        weight,
+        padding=[0, 0, 0, 0],
+        groups=4,
+        channels=4,
+        kernel_size=[3, 3],
+        data_layout="NHWC",
+        kernel_layout="HWOI",
+    )
+    op3 = relay.nn.batch_to_space_nd(op2, block_shape=[2, 2], crops=[[0, 1], [0, 1]])
+    expr = relay.op.add(op3, relay.const(-1.0))
+
+    op0 = relay.op.add(data, relay.const(1.0))
+    op1 = relay.nn.conv2d(
+        op0,
+        weight,
+        padding=[2, 2, 2, 2],
+        dilation=[2, 2],
+        groups=4,
+        channels=4,
+        kernel_size=[3, 3],
+        data_layout="NHWC",
+        kernel_layout="HWOI",
+    )
+    expected_expr = relay.op.add(op1, relay.const(-1.0))
+
+    compare_expected_fac(expr, expected_expr, [x_np])
+
+
+def test_fac_several():
+    """Several pattern entries in one expression."""
+    shape_x = [1, 5, 5, 4]
+    shape_w = [3, 3, 4, 1]
+
+    x_np = np.random.randint(-128, 127, size=shape_x, dtype="int8").astype("float32")
+    w_np = np.random.randint(-128, 127, size=shape_w, dtype="int8").astype("float32")
+
+    weight = relay.const(w_np)
+    data = relay.var("data", shape=shape_x, dtype="float32")
+    op1 = relay.nn.space_to_batch_nd(data, block_shape=[2, 2], paddings=[[2, 3], [2, 3]])
+    op2 = relay.nn.conv2d(
+        op1,
+        weight,
+        padding=[0, 0, 0, 0],
+        groups=4,
+        channels=4,
+        kernel_size=[3, 3],
+        data_layout="NHWC",
+        kernel_layout="HWOI",
+    )
+    op3 = relay.nn.batch_to_space_nd(op2, block_shape=[2, 2], crops=[[0, 1], [0, 1]])
+    op4 = relay.nn.space_to_batch_nd(op3, block_shape=[4, 4], paddings=[[4, 7], [4, 7]])
+    op5 = relay.nn.conv2d(
+        op4,
+        weight,
+        padding=[0, 0, 0, 0],
+        groups=4,
+        channels=4,
+        kernel_size=[3, 3],
+        data_layout="NHWC",
+        kernel_layout="HWOI",
+    )
+    expr = relay.nn.batch_to_space_nd(op5, block_shape=[4, 4], crops=[[0, 3], [0, 3]])
+
+    op1 = relay.nn.conv2d(
+        data,
+        weight,
+        padding=[2, 2, 2, 2],
+        dilation=[2, 2],
+        groups=4,
+        channels=4,
+        kernel_size=[3, 3],
+        data_layout="NHWC",
+        kernel_layout="HWOI",
+    )
+
+    expected_expr = relay.nn.conv2d(
+        op1,
+        weight,
+        padding=[4, 4, 4, 4],
+        dilation=[4, 4],
+        groups=4,
+        channels=4,
+        kernel_size=[3, 3],
+        data_layout="NHWC",
+        kernel_layout="HWOI",
+    )
+
+    compare_expected_fac(expr, expected_expr, [x_np])
+
+
+def test_fac_only_s2b_conv():
+    """Negative case: only the operations space_to_batch_nd-conv2d."""
+    shape_x = [1, 5, 5, 4]
+    shape_w = [3, 3, 4, 1]
+
+    x_np = np.random.randint(-128, 127, size=shape_x, dtype="int8").astype("float32")
+    w_np = np.random.randint(-128, 127, size=shape_w, dtype="int8").astype("float32")
+
+    weight = relay.const(w_np)
+    data = relay.var("data", shape=shape_x, dtype="float32")
+    op1 = relay.nn.space_to_batch_nd(data, block_shape=[2, 2], paddings=[[2, 3], [2, 3]])
+    expr = relay.nn.conv2d(
+        op1,
+        weight,
+        padding=[0, 0, 0, 0],
+        groups=4,
+        channels=4,
+        kernel_size=[3, 3],
+        data_layout="NHWC",
+        kernel_layout="HWOI",
+    )
+
+    expected_expr = expr
+
+    compare_expected_fac(expr, expected_expr, [x_np])
+
+
+def test_fac_only_s2b():
+    """Negative case: only the operation space_to_batch_nd."""
+    shape_x = [1, 5, 5, 4]
+
+    x_np = np.random.randint(-128, 127, size=shape_x, dtype="int8").astype("float32")
+
+    data = relay.var("data", shape=shape_x, dtype="float32")
+    expr = relay.nn.space_to_batch_nd(data, block_shape=[2, 2], paddings=[[2, 3], [2, 3]])
+
+    expected_expr = expr
+
+    compare_expected_fac(expr, expected_expr, [x_np])
+
+
+def test_fac_only_conv_b2s():
+    """Negative case: only the operations conv2d-batch_to_space_nd."""
+    shape_x = [1, 5, 5, 4]
+    shape_w = [3, 3, 4, 1]
+
+    x_np = np.random.randint(-128, 127, size=shape_x, dtype="int8").astype("float32")
+    w_np = np.random.randint(-128, 127, size=shape_w, dtype="int8").astype("float32")
+
+    weight = relay.const(w_np)
+    data = relay.var("data", shape=shape_x, dtype="float32")
+    op1 = relay.nn.conv2d(
+        data,
+        weight,
+        padding=[0, 0, 0, 0],
+        groups=4,
+        channels=4,
+        kernel_size=[3, 3],
+        data_layout="NHWC",
+        kernel_layout="HWOI",
+    )
+    expr = relay.nn.batch_to_space_nd(op1, block_shape=[2, 2], crops=[[0, 1], [0, 1]])
+
+    expected_expr = expr
+
+    compare_expected_fac(expr, expected_expr, [x_np])
+
+
+def test_fac_only_b2s():
+    """Negative case: only the operation batch_to_space_nd."""
+    shape_x = [1, 5, 5, 4]
+
+    x_np = np.random.randint(-128, 127, size=shape_x, dtype="int8").astype("float32")
+
+    data = relay.var("data", shape=shape_x, dtype="float32")
+    expr = relay.nn.batch_to_space_nd(data, block_shape=[2, 2], crops=[[0, 1], [0, 1]])
+
+    expected_expr = expr
+
+    compare_expected_fac(expr, expected_expr, [x_np])
+
+
+def test_fac_op_btwn_s2b_conv():
+    """Negative case: an add operation between space_to_batch_nd and conv2d."""
+    shape_x = [1, 5, 5, 4]
+    shape_w = [3, 3, 4, 1]
+
+    x_np = np.random.randint(-128, 127, size=shape_x, dtype="int8").astype("float32")
+    w_np = np.random.randint(-128, 127, size=shape_w, dtype="int8").astype("float32")
+
+    weight = relay.const(w_np)
+    data = relay.var("data", shape=shape_x, dtype="float32")
+    op1 = relay.nn.space_to_batch_nd(data, block_shape=[2, 2], paddings=[[2, 3], [2, 3]])
+    op_1_5 = relay.op.add(op1, relay.const(1.0))
+    op2 = relay.nn.conv2d(
+        op_1_5,
+        weight,
+        padding=[0, 0, 0, 0],
+        groups=4,
+        channels=4,
+        kernel_size=[3, 3],
+        data_layout="NHWC",
+        kernel_layout="HWOI",
+    )
+    expr = relay.nn.batch_to_space_nd(op2, block_shape=[2, 2], crops=[[0, 1], [0, 1]])
+
+    expected_expr = expr
+
+    compare_expected_fac(expr, expected_expr, [x_np])
+
+
+def test_fac_op_btwn_conv_b2s():
+    """Negative case: an add operation between conv2d and batch_to_space_nd."""
+    shape_x = [1, 5, 5, 4]
+    shape_w = [3, 3, 4, 1]
+
+    x_np = np.random.randint(-128, 127, size=shape_x, dtype="int8").astype("float32")
+    w_np = np.random.randint(-128, 127, size=shape_w, dtype="int8").astype("float32")
+
+    weight = relay.const(w_np)
+    data = relay.var("data", shape=shape_x, dtype="float32")
+    op1 = relay.nn.space_to_batch_nd(data, block_shape=[2, 2], paddings=[[2, 3], [2, 3]])
+    op2 = relay.nn.conv2d(
+        op1,
+        weight,
+        padding=[0, 0, 0, 0],
+        groups=4,
+        channels=4,
+        kernel_size=[3, 3],
+        data_layout="NHWC",
+        kernel_layout="HWOI",
+    )
+    op_2_5 = relay.op.add(op2, relay.const(1.0))
+    expr = relay.nn.batch_to_space_nd(op_2_5, block_shape=[2, 2], crops=[[0, 1], [0, 1]])
+
+    expected_expr = expr
+
+    compare_expected_fac(expr, expected_expr, [x_np])
+
+
+if __name__ == "__main__":
+    import sys
+
+    sys.exit(pytest.main([__file__] + sys.argv[1:]))