diff --git a/src/operator/subgraph/tensorrt/nnvm_to_onnx-inl.h b/src/operator/subgraph/tensorrt/nnvm_to_onnx-inl.h
index 55b3d938df0a..5a433f1d9820 100644
--- a/src/operator/subgraph/tensorrt/nnvm_to_onnx-inl.h
+++ b/src/operator/subgraph/tensorrt/nnvm_to_onnx-inl.h
@@ -41,6 +41,8 @@ namespace mxnet {
 namespace op {
 namespace nnvm_to_onnx {
 
+enum ConvDeconvType {Convolution, Deconvolution};
+
 using namespace nnvm;
 using namespace ::onnx;
 using int64 = ::google::protobuf::int64;
@@ -48,8 +50,7 @@ using int64 = ::google::protobuf::int64;
 std::unordered_map<std::string, TShape> GetPlaceholderShapes(const ShapeVector& shape_inputs,
     const nnvm::IndexedGraph& ig);
 
-std::unordered_map<std::string, int> GetPlaceholderDTypes(const DTypeVector&
-dtype_inputs,
+std::unordered_map<std::string, int> GetPlaceholderDTypes(const DTypeVector& dtype_inputs,
     const nnvm::IndexedGraph& ig);
 
 std::unordered_map<std::string, uint32_t> GetOutputLookup(const nnvm::IndexedGraph& ig);
@@ -74,12 +75,25 @@ typedef void (*ConverterFunction)(NodeProto *node_proto,
                                   const NodeAttrs &attrs,
                                   const nnvm::IndexedGraph &ig,
                                   const array_view<IndexedGraph::NodeEntry> &inputs);
 
+template <class ConvDeconvParam>
+void ConvDeconvConvertHelper(NodeProto *node_proto,
+                             const NodeAttrs &attrs,
+                             const nnvm::IndexedGraph &ig,
+                             const array_view<IndexedGraph::NodeEntry> &inputs,
+                             const ConvDeconvParam& param,
+                             ConvDeconvType type);
+
 // Forward declarations
 void ConvertConvolution(NodeProto *node_proto,
                         const NodeAttrs &attrs,
                         const nnvm::IndexedGraph &ig,
                         const array_view<IndexedGraph::NodeEntry> &inputs);
 
+void ConvertDeconvolution(NodeProto *node_proto,
+                          const NodeAttrs &attrs,
+                          const nnvm::IndexedGraph &ig,
+                          const array_view<IndexedGraph::NodeEntry> &inputs);
+
 void ConvertPooling(NodeProto *node_proto,
                     const NodeAttrs &attrs,
                     const nnvm::IndexedGraph &ig,
@@ -158,6 +172,7 @@ static const std::unordered_map<std::string, ConverterFunction> converter_map =
   {"BatchNorm", ConvertBatchNorm},
   {"clip", ConvertClip},
   {"Convolution", ConvertConvolution},
+  {"Deconvolution", ConvertDeconvolution},
   {"Concat", ConvertConcatenate},
   {"Dropout", ConvertDropout},
   {"elemwise_add", ConvertElementwiseAdd},
diff --git a/src/operator/subgraph/tensorrt/nnvm_to_onnx.cc b/src/operator/subgraph/tensorrt/nnvm_to_onnx.cc
index 6116f296e300..84580d0b05d0 100644
--- a/src/operator/subgraph/tensorrt/nnvm_to_onnx.cc
+++ b/src/operator/subgraph/tensorrt/nnvm_to_onnx.cc
@@ -31,6 +31,7 @@
 #include
 #include
 #include
+#include
 
 #include "../../../common/utils.h"
 #include "../../../ndarray/ndarray_function.h"
@@ -170,20 +171,25 @@ std::string ConvertNnvmGraphToOnnx(
   return serialized_onnx_graph;
 }
 
-void ConvertConvolution(NodeProto* node_proto, const NodeAttrs& attrs,
-                        const nnvm::IndexedGraph& /*ig*/,
-                        const array_view<IndexedGraph::NodeEntry>& /*inputs*/) {
-  const auto& conv_param = nnvm::get<ConvolutionParam>(attrs.parsed);
-
-  node_proto->set_op_type("Conv");
+template <class ConvDeconvParam>
+void ConvDeconvConvertHelper(NodeProto* node_proto, const NodeAttrs& attrs,
+                             const nnvm::IndexedGraph& /*ig*/,
+                             const array_view<IndexedGraph::NodeEntry>& /*input*/,
+                             const ConvDeconvParam& param,
+                             ConvDeconvType type) {
+  if (type == ConvDeconvType::Convolution) {
+    node_proto->set_op_type("Conv");
+  } else {
+    node_proto->set_op_type("ConvTranspose");
+  }
 
-  const mxnet::TShape kernel = conv_param.kernel;
-  const mxnet::TShape stride = conv_param.stride;
-  const mxnet::TShape dilate = conv_param.dilate;
-  const mxnet::TShape pad = conv_param.pad;
-  const uint32_t num_group = conv_param.num_group;
+  const mxnet::TShape kernel = param.kernel;
+  const mxnet::TShape stride = param.stride;
+  const mxnet::TShape dilate = param.dilate;
+  const mxnet::TShape pad = param.pad;
+  const uint32_t num_group = param.num_group;
   // const bool no_bias = conv_param.no_bias;
-  const dmlc::optional<int> layout = conv_param.layout;
+  const dmlc::optional<int> layout = param.layout;
 
   // dilations
   AttributeProto* const dilations = node_proto->add_attribute();
@@ -226,8 +232,24 @@ void ConvertConvolution(NodeProto* node_proto, const NodeAttrs& attrs,
   for (const dim_t kval : stride) {
     strides->add_ints(static_cast<int64>(kval));
   }
+}
+
+void ConvertConvolution(NodeProto* node_proto, const NodeAttrs& attrs,
+                        const nnvm::IndexedGraph& ig,
+                        const array_view<IndexedGraph::NodeEntry>& inputs) {
+  const auto& conv_param = nnvm::get<ConvolutionParam>(attrs.parsed);
+  ConvDeconvConvertHelper(node_proto, attrs, ig, inputs, conv_param,
+                          ConvDeconvType::Convolution);
 }  // end ConvertConvolution
 
+void ConvertDeconvolution(NodeProto* node_proto, const NodeAttrs& attrs,
+                          const nnvm::IndexedGraph& ig,
+                          const array_view<IndexedGraph::NodeEntry>& inputs) {
+  const auto& deconv_param = nnvm::get<DeconvolutionParam>(attrs.parsed);
+  ConvDeconvConvertHelper(node_proto, attrs, ig, inputs, deconv_param,
+                          ConvDeconvType::Deconvolution);
+}  // end ConvertDeconvolution
+
 void ConvertPooling(NodeProto* node_proto, const NodeAttrs& attrs,
                     const nnvm::IndexedGraph& /*ig*/,
                     const array_view<IndexedGraph::NodeEntry>& /*inputs*/) {
diff --git a/src/operator/subgraph/tensorrt/tensorrt-inl.h b/src/operator/subgraph/tensorrt/tensorrt-inl.h
index c175ac4d2aa3..7b0dcc1bd4a5 100644
--- a/src/operator/subgraph/tensorrt/tensorrt-inl.h
+++ b/src/operator/subgraph/tensorrt/tensorrt-inl.h
@@ -89,6 +89,7 @@ class TensorrtSelector : public SubgraphSelector {
     "clip",
     "Concat",
     "Convolution",
+    "Deconvolution",
     "Dropout",
     "elemwise_add",
     "elemwise_sub",
@@ -105,6 +106,7 @@ class TensorrtSelector : public SubgraphSelector {
   const std::unordered_set<std::string> withWeightsOps = {
     "BatchNorm",
     "Convolution",
+    "Deconvolution",
     "FullyConnected"
   };
diff --git a/tests/python/tensorrt/test_tensorrt_deconv.py b/tests/python/tensorrt/test_tensorrt_deconv.py
new file mode 100644
index 000000000000..ef567d1dae3c
--- /dev/null
+++ b/tests/python/tensorrt/test_tensorrt_deconv.py
@@ -0,0 +1,63 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import mxnet as mx
+from mxnet.test_utils import assert_almost_equal
+
+def get_params():
+    arg_params = {}
+    aux_params = {}
+    arg_params["trt_bn_test_conv_weight"] = mx.nd.ones((1, 1, 3, 3))
+    arg_params["trt_bn_test_deconv_weight"] = mx.nd.ones((1, 1, 3, 3))
+    return arg_params, aux_params
+
+def get_symbol():
+    data = mx.sym.Variable("data")
+    conv = mx.sym.Convolution(data=data, kernel=(3,3), no_bias=True, num_filter=1, num_group=1,
+                              name="trt_bn_test_conv")
+    deconv = mx.sym.Deconvolution(data=conv, kernel=(3, 3), no_bias=True, num_filter=1,
+                                  num_group=1, name="trt_bn_test_deconv")
+    return deconv
+
+def test_deconvolution_produce_same_output_as_tensorrt():
+    arg_params, aux_params = get_params()
+    arg_params_trt, aux_params_trt = get_params()
+
+    sym = get_symbol()
+    sym_trt = get_symbol().get_backend_symbol("TensorRT")
+
+    mx.contrib.tensorrt.init_tensorrt_params(sym_trt, arg_params_trt, aux_params_trt)
+
+    executor = sym.simple_bind(ctx=mx.gpu(), data=(1, 1, 3, 3), grad_req='null', force_rebind=True)
+    executor.copy_params_from(arg_params, aux_params)
+
+    executor_trt = sym_trt.simple_bind(ctx=mx.gpu(), data=(1, 1, 3, 3), grad_req='null',
+                                       force_rebind=True)
+    executor_trt.copy_params_from(arg_params_trt, aux_params_trt)
+
+    input_data = mx.nd.random.uniform(low=0, high=1, shape=(1, 1, 3, 3))
+
+    y = executor.forward(is_train=False, data=input_data)
+    y_trt = executor_trt.forward(is_train=False, data=input_data)
+
+    print(y[0].asnumpy())
+    print(y_trt[0].asnumpy())
+    assert_almost_equal(y[0].asnumpy(), y_trt[0].asnumpy(), 1e-4, 1e-4)
+
+if __name__ == '__main__':
+    import nose
+    nose.runmodule()