From 9a07322da538b3870325f72dcee5b6e04e392fe6 Mon Sep 17 00:00:00 2001
From: Serge Panev <spanev@nvidia.com>
Date: Wed, 6 May 2020 21:30:53 -0700
Subject: [PATCH 1/4] Update to TRT 7 API

Signed-off-by: Serge Panev <spanev@nvidia.com>
---
 src/operator/subgraph/tensorrt/onnx_to_tensorrt.cc | 4 +++-
 src/operator/subgraph/tensorrt/tensorrt.cu         | 3 +--
 2 files changed, 4 insertions(+), 3 deletions(-)

diff --git a/src/operator/subgraph/tensorrt/onnx_to_tensorrt.cc b/src/operator/subgraph/tensorrt/onnx_to_tensorrt.cc
index b02d1094183f..4c6bc7d101fb 100644
--- a/src/operator/subgraph/tensorrt/onnx_to_tensorrt.cc
+++ b/src/operator/subgraph/tensorrt/onnx_to_tensorrt.cc
@@ -78,7 +78,9 @@ std::tuple<unique_ptr<nvinfer1::ICudaEngine>,
 
   auto trt_logger = std::unique_ptr<TRT_Logger>(new TRT_Logger(verbosity));
   auto trt_builder = InferObject(nvinfer1::createInferBuilder(*trt_logger));
-  auto trt_network = InferObject(trt_builder->createNetwork());
+  const auto explicitBatch = 1U << static_cast<uint32_t>(
+                             nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH);
+  auto trt_network = InferObject(trt_builder->createNetworkV2(explicitBatch));
   auto trt_parser  = InferObject(nvonnxparser::createParser(*trt_network, *trt_logger));
   ::ONNX_NAMESPACE::ModelProto parsed_model;
   // We check for a valid parse, but the main effect is the side effect
diff --git a/src/operator/subgraph/tensorrt/tensorrt.cu b/src/operator/subgraph/tensorrt/tensorrt.cu
index 4a5b23b3a9f7..80429adb43c5 100644
--- a/src/operator/subgraph/tensorrt/tensorrt.cu
+++ b/src/operator/subgraph/tensorrt/tensorrt.cu
@@ -56,8 +56,7 @@ void TRTCompute(const OpStatePtr& state, const OpContext& ctx,
       param.bindings->at(i) = outputs[p.first].dptr_;
     }
   }
-  const int batch_size = static_cast<int>(inputs[0].shape_[0]);
-  param.trt_executor->enqueue(batch_size, param.bindings->data(), cuda_s, nullptr);
+  param.trt_executor->enqueueV2(param.bindings->data(), cuda_s, nullptr);
 }
 
 NNVM_REGISTER_OP(_TensorRT)

From 1ccd7dd8bb89742bc000950484b1d9549d61bbac Mon Sep 17 00:00:00 2001
From: Serge Panev <spanev@nvidia.com>
Date: Fri, 5 Jun 2020 03:02:33 -0700
Subject: [PATCH 2/4] Add PrePartition param caching - move
 init_tensorrt_params logic

Signed-off-by: Serge Panev <spanev@nvidia.com>
---
 src/operator/subgraph/build_subgraph.cc       |  2 +-
 src/operator/subgraph/tensorrt/tensorrt-inl.h | 45 ++++++++++++++++---
 src/operator/subgraph/tensorrt/tensorrt.cu    |  3 +-
 3 files changed, 43 insertions(+), 7 deletions(-)

diff --git a/src/operator/subgraph/build_subgraph.cc b/src/operator/subgraph/build_subgraph.cc
index 413395c3b74f..c2fdfacf4d28 100644
--- a/src/operator/subgraph/build_subgraph.cc
+++ b/src/operator/subgraph/build_subgraph.cc
@@ -428,7 +428,7 @@ void SortEntries(const std::unordered_map<const nnvm::NodeEntry*, size_t>& entry
 }
 
 /*!
- * \brief Given a subgraph, find the output entries of a subgraph.
+ * \brief Given a subgraph, find the input entries of a subgraph.
  * \param g pointer to the whole graph
  * \param simple_nods vector of simple nodes in top sorted order
  * \param subgraph_nodes vector of pointers of simples of a subgraph.
diff --git a/src/operator/subgraph/tensorrt/tensorrt-inl.h b/src/operator/subgraph/tensorrt/tensorrt-inl.h
index 16cc13006d59..d20f82f348e1 100644
--- a/src/operator/subgraph/tensorrt/tensorrt-inl.h
+++ b/src/operator/subgraph/tensorrt/tensorrt-inl.h
@@ -267,6 +267,24 @@ class TensorrtProperty : public SubgraphProperty {
     return std::make_shared<TensorrtProperty>();
   }
 
+  void PrePartition(const nnvm::Graph& g,
+    const std::vector<std::pair<std::string, std::string>>& options_map) override {
+    auto& in_arg_names = g.GetAttr<std::vector<std::string>>("in_arg_names");
+    auto& in_aux_names = g.GetAttr<std::vector<std::string>>("in_aux_names");
+    NDArray **in_args_ptr = g.GetAttr<NDArray**>("in_args");
+    NDArray **in_aux_ptr = g.GetAttr<NDArray**>("in_aux");
+    // should we check if not empty?
+    in_args_dict.clear();
+    in_aux_dict.clear();
+    // we trust the Python API, len(in_arg_names) == len(in_args_ptr)
+    for (unsigned i = 0; i < in_arg_names.size(); ++i) {
+      in_args_dict[in_arg_names[i]] = in_args_ptr[i];
+    }
+    for (unsigned i = 0; i < in_aux_names.size(); ++i) {
+      in_aux_dict[in_aux_names[i]] = in_aux_ptr[i];
+    }
+  }
+
   nnvm::ObjectPtr CreateSubgraphNode(const nnvm::Symbol &sym,
                                    const int subgraph_id) const override {
     nnvm::ObjectPtr n = nnvm::Node::Create();
@@ -280,16 +298,31 @@ class TensorrtProperty : public SubgraphProperty {
     n->attrs.op = Op::Get("_TensorRT");
     CHECK(n->attrs.op);
     n->attrs.subgraphs.emplace_back(std::make_shared<nnvm::Symbol>(new_sym));
+
+    // Mapping subgraph params with NDArrays
+    TRTParam param;
     std::ostringstream params_oss;
-    for (auto &e : new_sym.ListInputNames(nnvm::Symbol::kAll)) {
-      params_oss << e << ";";
+    for (auto &param_name : new_sym.ListInputNames(nnvm::Symbol::kAll)) {
+      NDArray *cache = nullptr;  // stays null when the name is in neither dict
+      auto it_args = in_args_dict.find(param_name);
+      if (it_args != in_args_dict.end()) {
+        cache = it_args->second;
+      } else {
+        auto it_aux = in_aux_dict.find(param_name);
+        if (it_aux != in_aux_dict.end()) {
+          cache = it_aux->second;
+        }
+      }
+      if (cache != nullptr) {
+        param.params_map.emplace(param_name, cache->Copy(Context()));
+        param.params_map[param_name].WaitToRead();
+        params_oss << param_name << ";";
+      }
     }
     auto tensorrt_params_names = params_oss.str();
     tensorrt_params_names.pop_back();
-    n->attrs.dict["subgraph_params_names"] = tensorrt_params_names;
-    TRTParam param;
     n->attrs.parsed = param;
-    n->op()->attr_parser(&(n->attrs));
+    n->attrs.dict["subgraph_params_names"] = tensorrt_params_names;
     return n;
   }
 
@@ -328,6 +361,8 @@ class TensorrtProperty : public SubgraphProperty {
     }
     subgraph_node->attrs.parsed = std::move(_params);
   }
+
+  std::unordered_map<std::string, NDArray*> in_args_dict, in_aux_dict;
 };
 
 
diff --git a/src/operator/subgraph/tensorrt/tensorrt.cu b/src/operator/subgraph/tensorrt/tensorrt.cu
index 80429adb43c5..826f9a5876b6 100644
--- a/src/operator/subgraph/tensorrt/tensorrt.cu
+++ b/src/operator/subgraph/tensorrt/tensorrt.cu
@@ -60,7 +60,8 @@ void TRTCompute(const OpStatePtr& state, const OpContext& ctx,
 }
 
 NNVM_REGISTER_OP(_TensorRT)
-.set_attr<FStatefulCompute>("FStatefulCompute<gpu>", TRTCompute);
+.set_attr<FStatefulCompute>("FStatefulCompute<gpu>", TRTCompute)
+.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes);
 
 }  // namespace op
 }  // namespace mxnet

From 7a2eecd922abe5552ef9aaac4831cf9454d2e83a Mon Sep 17 00:00:00 2001
From: Serge Panev <spanev@nvidia.com>
Date: Wed, 10 Jun 2020 01:36:38 -0700
Subject: [PATCH 3/4] Handle node with no defined input

Signed-off-by: Serge Panev <spanev@nvidia.com>
---
 src/operator/subgraph/tensorrt/tensorrt-inl.h | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/src/operator/subgraph/tensorrt/tensorrt-inl.h b/src/operator/subgraph/tensorrt/tensorrt-inl.h
index d20f82f348e1..8da038862ca9 100644
--- a/src/operator/subgraph/tensorrt/tensorrt-inl.h
+++ b/src/operator/subgraph/tensorrt/tensorrt-inl.h
@@ -320,7 +320,9 @@ class TensorrtProperty : public SubgraphProperty {
       }
     }
     auto tensorrt_params_names = params_oss.str();
-    tensorrt_params_names.pop_back();
+    if (!tensorrt_params_names.empty()) {
+      tensorrt_params_names.pop_back();
+    }
     n->attrs.parsed = param;
     n->attrs.dict["subgraph_params_names"] = tensorrt_params_names;
     return n;

From c7bfeebe38f4b28ea1683503a97dd15c9d3175b2 Mon Sep 17 00:00:00 2001
From: Serge Panev <spanev@nvidia.com>
Date: Tue, 30 Jun 2020 17:52:21 -0700
Subject: [PATCH 4/4] Remove tmp comment

Signed-off-by: Serge Panev <spanev@nvidia.com>
---
 src/operator/subgraph/tensorrt/tensorrt-inl.h | 1 -
 1 file changed, 1 deletion(-)

diff --git a/src/operator/subgraph/tensorrt/tensorrt-inl.h b/src/operator/subgraph/tensorrt/tensorrt-inl.h
index 8da038862ca9..b35a1715000e 100644
--- a/src/operator/subgraph/tensorrt/tensorrt-inl.h
+++ b/src/operator/subgraph/tensorrt/tensorrt-inl.h
@@ -273,7 +273,6 @@ class TensorrtProperty : public SubgraphProperty {
     auto& in_aux_names = g.GetAttr<std::vector<std::string>>("in_aux_names");
     NDArray **in_args_ptr = g.GetAttr<NDArray**>("in_args");
     NDArray **in_aux_ptr = g.GetAttr<NDArray**>("in_aux");
-    // should we check if not empty?
     in_args_dict.clear();
     in_aux_dict.clear();
     // we trust the Python API, len(in_arg_names) == len(in_args_ptr)