Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[v1.5.x] add deconv in TRT subgraph (#15666) #16043

Merged
merged 1 commit into from
Aug 31, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 17 additions & 2 deletions src/operator/subgraph/tensorrt/nnvm_to_onnx-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,15 +41,16 @@ namespace mxnet {
namespace op {
namespace nnvm_to_onnx {

// Tells ConvDeconvConvertHelper which ONNX operator to emit for a node.
enum ConvDeconvType {
  Convolution,   // helper sets the ONNX op type to "Conv"
  Deconvolution  // helper sets the ONNX op type to "ConvTranspose"
};

using namespace nnvm;
using namespace ::onnx;
using int64 = ::google::protobuf::int64;

// Declaration only; the definition is not part of this view.
// Returns a map from std::string to mxnet::TShape computed from `shape_inputs`
// and the indexed graph `ig`.
// NOTE(review): presumably the keys are placeholder (input) names and the
// values their shapes -- confirm against the definition.
std::unordered_map<std::string, mxnet::TShape> GetPlaceholderShapes(const ShapeVector& shape_inputs,
const nnvm::IndexedGraph& ig);

std::unordered_map<std::string, int> GetPlaceholderDTypes(const DTypeVector&
dtype_inputs,
std::unordered_map<std::string, int> GetPlaceholderDTypes(const DTypeVector& dtype_inputs,
const nnvm::IndexedGraph& ig);

// Declaration only; the definition is not part of this view.
// Returns a map from std::string to uint32_t derived from the indexed graph.
// NOTE(review): presumably output name -> output index -- confirm against the
// definition.
std::unordered_map<std::string, uint32_t> GetOutputLookup(const nnvm::IndexedGraph& ig);
Expand All @@ -74,12 +75,25 @@ typedef void (*ConverterFunction)(NodeProto *node_proto,
const nnvm::IndexedGraph &ig,
const array_view<IndexedGraph::NodeEntry> &inputs);

// Shared implementation behind ConvertConvolution and ConvertDeconvolution.
//
// ConvDeconvParam is the parsed parameter struct of the operator being
// converted; the definition reads its kernel, stride, dilate, pad, num_group
// and layout members, so any type passed here must provide them.
//
// node_proto  ONNX node to fill in. Its op type is set to "Conv" when `type`
//             is ConvDeconvType::Convolution and to "ConvTranspose" otherwise,
//             and attributes (e.g. dilations, strides) are added to it.
// attrs, ig, inputs
//             Forwarded by the converters; the definition leaves `ig` and
//             `inputs` unnamed, i.e. unused.
// param       Parsed operator parameters (see above).
// type        Selects between the Conv and ConvTranspose op types.
//
// NOTE(review): only the declaration lives in this header; the template body
// is in nnvm_to_onnx.cc, so it can only be instantiated where that definition
// is visible -- confirm no other translation unit calls it.
template <class ConvDeconvParam>
void ConvDeconvConvertHelper(NodeProto *node_proto,
const NodeAttrs &attrs,
const nnvm::IndexedGraph &ig,
const array_view<IndexedGraph::NodeEntry> &inputs,
const ConvDeconvParam& param,
ConvDeconvType type);

// Forward declarations
// Converter registered in converter_map under "Convolution". Its definition
// reads attrs.parsed as op::ConvolutionParam and delegates to
// ConvDeconvConvertHelper with ConvDeconvType::Convolution.
void ConvertConvolution(NodeProto *node_proto,
const NodeAttrs &attrs,
const nnvm::IndexedGraph &ig,
const array_view<IndexedGraph::NodeEntry> &inputs);

// Converter registered in converter_map under "Deconvolution". Its definition
// reads attrs.parsed as op::DeconvolutionParam and delegates to
// ConvDeconvConvertHelper with ConvDeconvType::Deconvolution.
void ConvertDeconvolution(NodeProto *node_proto,
const NodeAttrs &attrs,
const nnvm::IndexedGraph &ig,
const array_view<IndexedGraph::NodeEntry> &inputs);

void ConvertPooling(NodeProto *node_proto,
const NodeAttrs &attrs,
const nnvm::IndexedGraph &ig,
Expand Down Expand Up @@ -158,6 +172,7 @@ static const std::unordered_map<std::string, ConverterFunction> converter_map =
{"BatchNorm", ConvertBatchNorm},
{"clip", ConvertClip},
{"Convolution", ConvertConvolution},
{"Deconvolution", ConvertDeconvolution},
{"Concat", ConvertConcatenate},
{"Dropout", ConvertDropout},
{"elemwise_add", ConvertElementwiseAdd},
Expand Down
46 changes: 34 additions & 12 deletions src/operator/subgraph/tensorrt/nnvm_to_onnx.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
#include <mxnet/base.h>
#include <nnvm/graph.h>
#include <nnvm/pass_functions.h>
#include <operator/nn/deconvolution-inl.h>

#include "../../../common/utils.h"
#include "../../../ndarray/ndarray_function.h"
Expand Down Expand Up @@ -170,20 +171,25 @@ std::string ConvertNnvmGraphToOnnx(
return serialized_onnx_graph;
}

void ConvertConvolution(NodeProto* node_proto, const NodeAttrs& attrs,
const nnvm::IndexedGraph& /*ig*/,
const array_view<IndexedGraph::NodeEntry>& /*inputs*/) {
const auto& conv_param = nnvm::get<op::ConvolutionParam>(attrs.parsed);

node_proto->set_op_type("Conv");
template <class ConvDeconvParam>
void ConvDeconvConvertHelper(NodeProto* node_proto, const NodeAttrs& attrs,
const nnvm::IndexedGraph& /*ig*/,
const array_view<IndexedGraph::NodeEntry>& /*input*/,
const ConvDeconvParam& param,
ConvDeconvType type) {
if (type == ConvDeconvType::Convolution) {
node_proto->set_op_type("Conv");
} else {
node_proto->set_op_type("ConvTranspose");
}

const mxnet::TShape kernel = conv_param.kernel;
const mxnet::TShape stride = conv_param.stride;
const mxnet::TShape dilate = conv_param.dilate;
const mxnet::TShape pad = conv_param.pad;
const uint32_t num_group = conv_param.num_group;
const mxnet::TShape kernel = param.kernel;
const mxnet::TShape stride = param.stride;
const mxnet::TShape dilate = param.dilate;
const mxnet::TShape pad = param.pad;
const uint32_t num_group = param.num_group;
// const bool no_bias = param.no_bias;
const dmlc::optional<int> layout = conv_param.layout;
const dmlc::optional<int> layout = param.layout;

// dilations
AttributeProto* const dilations = node_proto->add_attribute();
Expand Down Expand Up @@ -226,8 +232,24 @@ void ConvertConvolution(NodeProto* node_proto, const NodeAttrs& attrs,
for (const dim_t kval : stride) {
strides->add_ints(static_cast<int64>(kval));
}
}

// Fills `node_proto` for an MXNet Convolution node.
//
// The operator's parsed attributes are read as op::ConvolutionParam and all of
// the attribute emission is delegated to ConvDeconvConvertHelper, which is
// told to produce the convolution flavour of the ONNX node.
void ConvertConvolution(NodeProto* node_proto, const NodeAttrs& attrs,
                        const nnvm::IndexedGraph& ig,
                        const array_view<IndexedGraph::NodeEntry>& inputs) {
  const op::ConvolutionParam& parsed_param =
      nnvm::get<op::ConvolutionParam>(attrs.parsed);
  const ConvDeconvType kind = ConvDeconvType::Convolution;
  ConvDeconvConvertHelper(node_proto, attrs, ig, inputs, parsed_param, kind);
}  // end ConvertConvolution

// Fills `node_proto` for an MXNet Deconvolution node.
//
// The operator's parsed attributes are read as op::DeconvolutionParam and all
// of the attribute emission is delegated to ConvDeconvConvertHelper, which is
// told to produce the deconvolution flavour of the ONNX node.
void ConvertDeconvolution(NodeProto* node_proto, const NodeAttrs& attrs,
                          const nnvm::IndexedGraph& ig,
                          const array_view<IndexedGraph::NodeEntry>& inputs) {
  const op::DeconvolutionParam& parsed_param =
      nnvm::get<op::DeconvolutionParam>(attrs.parsed);
  const ConvDeconvType kind = ConvDeconvType::Deconvolution;
  ConvDeconvConvertHelper(node_proto, attrs, ig, inputs, parsed_param, kind);
}  // end ConvertDeconvolution

void ConvertPooling(NodeProto* node_proto, const NodeAttrs& attrs,
const nnvm::IndexedGraph& /*ig*/,
const array_view<IndexedGraph::NodeEntry>& /*inputs*/) {
Expand Down
2 changes: 2 additions & 0 deletions src/operator/subgraph/tensorrt/tensorrt-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ class TensorrtSelector : public SubgraphSelector {
"clip",
"Concat",
"Convolution",
"Deconvolution",
"Dropout",
"elemwise_add",
"elemwise_sub",
Expand All @@ -105,6 +106,7 @@ class TensorrtSelector : public SubgraphSelector {
// Operator names kept in a separate "with weights" set.
// NOTE(review): how the selector consults this set is not visible in this
// hunk -- presumably these are the ops whose weight inputs need special
// handling when building the subgraph; confirm against the selector logic.
const std::unordered_set<std::string> withWeightsOps = {
"BatchNorm",
"Convolution",
"Deconvolution",
"FullyConnected"
};

Expand Down
63 changes: 63 additions & 0 deletions tests/python/tensorrt/test_tensorrt_deconv.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import mxnet as mx
from mxnet.test_utils import assert_almost_equal

def get_params():
    """Return fresh ``(arg_params, aux_params)`` dicts for the test network.

    ``arg_params`` holds an all-ones (1, 1, 3, 3) weight for each of the two
    layers created by ``get_symbol``; ``aux_params`` is empty.
    """
    weight_names = ("trt_bn_test_conv_weight", "trt_bn_test_deconv_weight")
    args = {}
    for weight_name in weight_names:
        # A new NDArray per entry, so the dicts never share storage.
        args[weight_name] = mx.nd.ones((1, 1, 3, 3))
    aux = {}
    return args, aux

def get_symbol():
    """Build the symbol under test: a Convolution feeding a Deconvolution.

    Both layers use a 3x3 kernel, a single filter, a single group and no bias.
    """
    layer_kwargs = dict(kernel=(3, 3), no_bias=True, num_filter=1, num_group=1)
    net = mx.sym.Variable("data")
    net = mx.sym.Convolution(data=net, name="trt_bn_test_conv", **layer_kwargs)
    net = mx.sym.Deconvolution(data=net, name="trt_bn_test_deconv", **layer_kwargs)
    return net

def test_deconvolution_produce_same_output_as_tensorrt():
    """Run the conv+deconv network with and without the TensorRT backend on
    the GPU and check that both executors produce (almost) the same output."""
    data_shape = (1, 1, 3, 3)
    tolerance = 1e-4

    # Independent parameter sets for the reference and the TensorRT executor.
    ref_args, ref_aux = get_params()
    trt_args, trt_aux = get_params()

    ref_sym = get_symbol()
    trt_sym = get_symbol().get_backend_symbol("TensorRT")

    mx.contrib.tensorrt.init_tensorrt_params(trt_sym, trt_args, trt_aux)

    ref_exec = ref_sym.simple_bind(ctx=mx.gpu(), data=data_shape, grad_req='null',
                                   force_rebind=True)
    ref_exec.copy_params_from(ref_args, ref_aux)

    trt_exec = trt_sym.simple_bind(ctx=mx.gpu(), data=data_shape, grad_req='null',
                                   force_rebind=True)
    trt_exec.copy_params_from(trt_args, trt_aux)

    # The same random input is fed to both executors.
    sample = mx.nd.random.uniform(low=0, high=1, shape=data_shape)

    ref_outputs = ref_exec.forward(is_train=False, data=sample)
    trt_outputs = trt_exec.forward(is_train=False, data=sample)

    ref_result = ref_outputs[0].asnumpy()
    trt_result = trt_outputs[0].asnumpy()

    print(ref_result)
    print(trt_result)
    assert_almost_equal(ref_result, trt_result, tolerance, tolerance)

if __name__ == "__main__":
    # Executed as a script: let nose discover and run the tests in this module.
    import nose
    nose.runmodule()