From 657124599a9641ecb3affae607c80515c925741d Mon Sep 17 00:00:00 2001 From: grimoire Date: Sat, 11 Jun 2022 22:25:03 +0800 Subject: [PATCH 1/8] Add fuse select assign pass --- .../torchscript/optimizer/CMakeLists.txt | 3 +- .../torchscript/optimizer/bind.cpp | 5 + .../optimizer/ir/subgraph_matcher.cpp | 2 + .../optimizer/ir/subgraph_matcher.h | 4 +- .../passes/onnx/fuse_select_assign.cpp | 148 ++++++++++++++++++ .../passes/onnx/fuse_select_assign.h | 17 ++ mmdeploy/apis/onnx/passes/optimize_onnx.py | 1 + tests/test_apis/test_onnx_passes.py | 48 ++++++ 8 files changed, 226 insertions(+), 2 deletions(-) create mode 100644 csrc/backend_ops/torchscript/optimizer/passes/onnx/fuse_select_assign.cpp create mode 100644 csrc/backend_ops/torchscript/optimizer/passes/onnx/fuse_select_assign.h diff --git a/csrc/backend_ops/torchscript/optimizer/CMakeLists.txt b/csrc/backend_ops/torchscript/optimizer/CMakeLists.txt index ead1e61a5a..1b5e75ccca 100644 --- a/csrc/backend_ops/torchscript/optimizer/CMakeLists.txt +++ b/csrc/backend_ops/torchscript/optimizer/CMakeLists.txt @@ -3,6 +3,7 @@ project(ts_optimizer) find_package(Torch REQUIRED) +find_library(TORCH_PYTHON_LIBRARY torch_python PATHS "${TORCH_INSTALL_PREFIX}/lib") if (NOT TARGET pybind11) add_subdirectory(${CMAKE_SOURCE_DIR}/third_party/pybind11 pybind11) endif () @@ -10,7 +11,7 @@ endif () file(GLOB_RECURSE OPTIMIZER_SRCS *.cpp) pybind11_add_module(${PROJECT_NAME} ${OPTIMIZER_SRCS}) -target_link_libraries(${PROJECT_NAME} PRIVATE ${TORCH_LIBRARIES}) +target_link_libraries(${PROJECT_NAME} PRIVATE ${TORCH_LIBRARIES} ${TORCH_PYTHON_LIBRARY}) target_link_directories(${PROJECT_NAME} PRIVATE mmdeploy::torchscript_ops) set_target_properties( ${PROJECT_NAME} PROPERTIES LIBRARY_OUTPUT_DIRECTORY diff --git a/csrc/backend_ops/torchscript/optimizer/bind.cpp b/csrc/backend_ops/torchscript/optimizer/bind.cpp index 21a691f141..660fa58d72 100644 --- a/csrc/backend_ops/torchscript/optimizer/bind.cpp +++ b/csrc/backend_ops/torchscript/optimizer/bind.cpp @@ -1,10 +1,13 @@ // Copyright (c) OpenMMLab. All rights reserved. #include +#include +#include #include #include "optimizer.h" #include "passes/onnx/flatten_cls_head.h" +#include "passes/onnx/fuse_select_assign.h" #include "passes/onnx/merge_shape_concate.h" #include "passes/onnx/onnx_peephole.h" @@ -33,6 +36,8 @@ PYBIND11_MODULE(ts_optimizer, m) { onnx_module.def("_jit_pass_merge_shape_concate", MergeShapeConcate, py::arg("graph")); onnx_module.def("_jit_pass_onnx_peephole", ONNXPeephole, py::arg("graph")); onnx_module.def("_jit_pass_flatten_cls_head", FlattenClsHead, py::arg("graph")); + onnx_module.def("_jit_pass_fuse_select_assign", FuseSelectAssign, py::arg("graph"), + py::arg("params")); } } // namespace torch_jit diff --git a/csrc/backend_ops/torchscript/optimizer/ir/subgraph_matcher.cpp b/csrc/backend_ops/torchscript/optimizer/ir/subgraph_matcher.cpp index 97425aa5b3..d7df0704fc 100644 --- a/csrc/backend_ops/torchscript/optimizer/ir/subgraph_matcher.cpp +++ b/csrc/backend_ops/torchscript/optimizer/ir/subgraph_matcher.cpp @@ -295,6 +295,8 @@ bool SubgraphMatcher::SubgraphMatcherImpl::matchesSubgraphFromAnchorNode(Node* a SubgraphMatcher::SubgraphMatcher(const Graph& pattern, MatchAttribute match_attribute) : impl_(new SubgraphMatcher::SubgraphMatcherImpl(pattern, match_attribute)) {} +SubgraphMatcher::~SubgraphMatcher() = default; + bool SubgraphMatcher::matchesSubgraphFromAnchorNode(Node* anchor) { return impl_->matchesSubgraphFromAnchorNode(anchor); } diff --git a/csrc/backend_ops/torchscript/optimizer/ir/subgraph_matcher.h b/csrc/backend_ops/torchscript/optimizer/ir/subgraph_matcher.h index 6629b598ec..e2488e252c 100644 --- a/csrc/backend_ops/torchscript/optimizer/ir/subgraph_matcher.h +++ b/csrc/backend_ops/torchscript/optimizer/ir/subgraph_matcher.h @@ -17,6 +17,8 @@ class SubgraphMatcher { public: explicit SubgraphMatcher(const Graph& pattern, MatchAttribute match_attribute = TRY_MATCH); + ~SubgraphMatcher(); + bool matchesSubgraphFromAnchorNode(Node* anchor); /** \brief Return match map for nodes. */ @@ -27,7 +29,7 @@ class SubgraphMatcher { private: class SubgraphMatcherImpl; - std::unique_ptr impl_ = nullptr; + std::unique_ptr impl_; }; } // namespace torch_jit diff --git a/csrc/backend_ops/torchscript/optimizer/passes/onnx/fuse_select_assign.cpp b/csrc/backend_ops/torchscript/optimizer/passes/onnx/fuse_select_assign.cpp new file mode 100644 index 0000000000..2ff4d09f70 --- /dev/null +++ b/csrc/backend_ops/torchscript/optimizer/passes/onnx/fuse_select_assign.cpp @@ -0,0 +1,148 @@ +#include "fuse_select_assign.h" + +#include + +#include "../../ir/subgraph_matcher.h" +#include "torch/csrc/jit/ir/irparser.h" + +namespace mmdeploy { +namespace torch_jit { + +using c10::Symbol; +using torch::jit::Block; +using torch::jit::IValue; +using torch::jit::Node; + +bool FuseSelectAssign(Node* node, std::unordered_map& params, + std::unordered_map& vmap, SubgraphMatcher& matcher) { + auto values_map = matcher.values_map(); + + auto cmp1 = values_map[vmap["cmp_1"]]->node(); + auto cmp2 = values_map[vmap["cmp_2"]]->node(); + if (cmp1 != cmp2) { + // cmp_1 == cmp_2, cmp in (Great, Less) + if (cmp1->kind() != cmp2->kind()) return false; + if (!(cmp1->kind() == Symbol::onnx("Greater") || cmp1->kind() == Symbol::onnx("Less"))) + return false; + + // check threshold + Node* cmps[] = {cmp1, cmp2}; + float thres = 0.0f; + Node* x = nullptr; + for (int i = 0; i < 2; ++i) { + auto cmp = cmps[i]; + auto threshold = cmp->inputs()[1]->node(); + if (threshold->kind() != Symbol::onnx("Constant")) return false; + auto thres_val = threshold->t(Symbol::attr("value")); + if (i == 0) { + thres = thres_val.data_ptr()[0]; + x = cmp->inputs()[0]->node(); + } else { + float tmp_val = thres_val.data_ptr()[0]; + if (fabs(thres - tmp_val) > 1e-10) { + return false; + } + if (x != cmp->inputs()[0]->node()) { + return false; + } + } + } + } + + { + // check shape of reshape + Node* shape = values_map[vmap["reshape_1_shape"]]->node(); + auto shape_val = shape->t(Symbol::attr("value")); + if (shape_val.dim() != 1) return false; + if (shape_val.data_ptr()[0] != -1) return false; + } + + { + // check transpose + Node* trans[] = {values_map[vmap["trans_1"]]->node(), values_map[vmap["trans_2"]]->node()}; + for (auto tran : trans) { + auto tran_perm = tran->is(Symbol::attr("perm")); + if (tran_perm.size() != 2) return false; + if (tran_perm[0] != 1 || tran_perm[1] != 0) return false; + } + } + + { + // check gather indice + Node* gather_inds = values_map[vmap["gather_inds_2"]]->node(); + auto inds_val = gather_inds->t(Symbol::attr("value")); + if (inds_val.dim() != 0) return false; + if (inds_val.data_ptr()[0] != 0) return false; + } + + { + // check slice start + Node* slice = values_map[vmap["slice_2"]]->node(); + auto start_name = slice->inputs()[1]->debugName(); + auto start_val = params[start_name]; + if (start_val.dim() != 1) return false; + if (start_val.data_ptr()[0] != 0) return false; + } + + // create new node + auto graph = node->owningGraph(); + auto z = values_map[vmap["z"]]; + auto y = values_map[vmap["y"]]; + auto where_node = graph->create(Symbol::onnx("Where"), {cmp1->output(), z, y}); + where_node->insertBefore(node); + where_node->output()->copyMetadata(node->output()); + node->output()->replaceAllUsesWith(where_node->output()); + return true; +} + +void FuseSelectAssign(Block* block, std::unordered_map& params, + std::unordered_map& vmap, SubgraphMatcher& matcher) { + auto graph = block->owningGraph(); + auto it = block->nodes().begin(); + while (it != block->nodes().end()) { + auto node = *it; + ++it; + for (auto block : node->blocks()) { + FuseSelectAssign(block, params, vmap, matcher); + } + + if (matcher.matchesSubgraphFromAnchorNode(node)) { + FuseSelectAssign(node, params, vmap, matcher); + } + } +} + +void FuseSelectAssign(std::shared_ptr& graph, + std::unordered_map& params) { + std::string pattern_str = R"IR( + graph(%y, %z, %cmp_1, %cmp_2, %start, %axes): + %nz_1 = onnx::NonZero(%cmp_1) + %trans_1 = onnx::Transpose(%nz_1) + %gather_1 = onnx::GatherND(%z, %trans_1) + %reshape_1_shape = onnx::Constant() + %reshape_1 = onnx::Reshape(%gather_1, %reshape_1_shape) + %shape_2 = onnx::Shape(%y) + %expand_2 = onnx::Expand(%cmp_2, %shape_2) + %nz_2 = onnx::NonZero(%expand_2) + %trans_2 = onnx::Transpose(%nz_2) + %trans_shape_2 = onnx::Shape(%trans_2) + %gather_inds_2 = onnx::Constant() + %gather_2 = onnx::Gather(%trans_shape_2, %gather_inds_2) + %unsqueeze_2 = onnx::Unsqueeze(%gather_2) + %slice_2 = onnx::Slice(%reshape_1, %start, %unsqueeze_2, %axes) + %scatter_2 = onnx::ScatterND(%y, %trans_2, %slice_2) + return (%scatter_2) + )IR"; + + Graph pattern; + std::unordered_map vmap; + torch::jit::parseIR(pattern_str, &pattern, vmap); + + SubgraphMatcher matcher(pattern, MatchAttribute::NO_MATCH); + FuseSelectAssign(graph->block(), params, vmap, matcher); + torch::jit::EliminateDeadCode( + graph->block(), true, + torch::jit::DCESideEffectPolicy::ALLOW_DELETING_NODES_WITH_SIDE_EFFECTS); +} +} // namespace torch_jit +} // namespace mmdeploy diff --git a/csrc/backend_ops/torchscript/optimizer/passes/onnx/fuse_select_assign.h b/csrc/backend_ops/torchscript/optimizer/passes/onnx/fuse_select_assign.h new file mode 100644 index 0000000000..afa0dc56d6 --- /dev/null +++ b/csrc/backend_ops/torchscript/optimizer/passes/onnx/fuse_select_assign.h @@ -0,0 +1,17 @@ +// Copyright (c) OpenMMLab. All rights reserved. +#ifndef _FUSE_SELECT_ASSIGN_H_ +#define _FUSE_SELECT_ASSIGN_H_ + +#include +namespace mmdeploy { +namespace torch_jit { +using torch::Tensor; +using torch::jit::Graph; + +// this pass is used to fuse y[x>thres] = z[x>thres] +void FuseSelectAssign(std::shared_ptr& graph, + std::unordered_map& params); +} // namespace torch_jit +} // namespace mmdeploy + +#endif diff --git a/mmdeploy/apis/onnx/passes/optimize_onnx.py b/mmdeploy/apis/onnx/passes/optimize_onnx.py index d413a513ef..0997713a09 100644 --- a/mmdeploy/apis/onnx/passes/optimize_onnx.py +++ b/mmdeploy/apis/onnx/passes/optimize_onnx.py @@ -10,6 +10,7 @@ def optimize_onnx(graph, params_dict, torch_out): ts_optimizer.onnx._jit_pass_merge_shape_concate(graph) ts_optimizer.onnx._jit_pass_onnx_peephole(graph) ts_optimizer.onnx._jit_pass_flatten_cls_head(graph) + ts_optimizer.onnx._jit_pass_fuse_select_assign(graph, params_dict) except Exception: pass diff --git a/tests/test_apis/test_onnx_passes.py b/tests/test_apis/test_onnx_passes.py index c7dc891c5f..cd11972877 100644 --- a/tests/test_apis/test_onnx_passes.py +++ b/tests/test_apis/test_onnx_passes.py @@ -188,3 +188,51 @@ def forward(self, x): node, idx = _find_next_node(idx + 1, nodes, 'Flatten') assert node is not None + + +def test_fuse_select_assign(): + pytest.importorskip('mmdeploy.backend.torchscript.ts_optimizer.onnx') + + try: + from mmdeploy.backend.torchscript import ts_optimizer + opt_pass = ts_optimizer.onnx._jit_pass_fuse_select_assign + except ImportError: + pytest.skip('pass not found.') + + def _optimize_onnx(graph, params_dict, torch_out): + opt_pass(graph, params_dict) + return graph, params_dict, torch_out + + class TestModel(torch.nn.Module): + + def __init__(self) -> None: + super().__init__() + + def forward(self, x): + z = x / 2 + y = torch.zeros_like(x) + y[x < 0.5] = z[x < 0.5] + return y + + model = TestModel() + x = torch.rand(1, 4, 8, 8) + + with RewriterContext({}, onnx_custom_passes=_optimize_onnx): + torch.onnx.export( + model, + x, + onnx_file, + input_names=['input'], + output_names=['output'], + dynamic_axes=dict(input={ + 2: 'h', + 3: 'w' + }), + opset_version=11) + + onnx_model = onnx.load(onnx_file) + graph = onnx_model.graph + nodes = graph.node + + node, _ = _find_next_node(0, nodes, 'Where') + assert node is not None From dede413c4cc524932b4eba2a98a4b5b5f4438cb8 Mon Sep 17 00:00:00 2001 From: grimoire Date: Tue, 14 Jun 2022 19:05:32 +0800 Subject: [PATCH 2/8] move code to csrc --- .../torchscript/optimizer/passes/onnx/fuse_select_assign.cpp | 0 .../torchscript/optimizer/passes/onnx/fuse_select_assign.h | 0 2 files changed, 0 insertions(+), 0 deletions(-) rename csrc/{ => mmdeploy}/backend_ops/torchscript/optimizer/passes/onnx/fuse_select_assign.cpp (100%) rename csrc/{ => mmdeploy}/backend_ops/torchscript/optimizer/passes/onnx/fuse_select_assign.h (100%) diff --git a/csrc/backend_ops/torchscript/optimizer/passes/onnx/fuse_select_assign.cpp b/csrc/mmdeploy/backend_ops/torchscript/optimizer/passes/onnx/fuse_select_assign.cpp similarity index 100% rename from csrc/backend_ops/torchscript/optimizer/passes/onnx/fuse_select_assign.cpp rename to csrc/mmdeploy/backend_ops/torchscript/optimizer/passes/onnx/fuse_select_assign.cpp diff --git a/csrc/backend_ops/torchscript/optimizer/passes/onnx/fuse_select_assign.h b/csrc/mmdeploy/backend_ops/torchscript/optimizer/passes/onnx/fuse_select_assign.h similarity index 100% rename from csrc/backend_ops/torchscript/optimizer/passes/onnx/fuse_select_assign.h rename to csrc/mmdeploy/backend_ops/torchscript/optimizer/passes/onnx/fuse_select_assign.h From ca32fda01e300cac8c04c7bb1dc67ff2c9a6daf7 Mon Sep 17 00:00:00 2001 From: grimoire Date: Tue, 21 Jun 2022 15:37:22 +0800 Subject: [PATCH 3/8] add config flag --- configs/_base_/onnx_config.py | 3 ++- mmdeploy/apis/core/pipeline_manager.py | 6 ++++-- mmdeploy/apis/onnx/passes/optimize_onnx.py | 6 +++++- mmdeploy/apis/pytorch2onnx.py | 4 +++- 4 files changed, 14 insertions(+), 5 deletions(-) diff --git a/configs/_base_/onnx_config.py b/configs/_base_/onnx_config.py index bf48e7ab77..43621b12b7 100644 --- a/configs/_base_/onnx_config.py +++ b/configs/_base_/onnx_config.py @@ -6,4 +6,5 @@ save_file='end2end.onnx', input_names=['input'], output_names=['output'], - input_shape=None) + input_shape=None, + optimize=True) diff --git a/mmdeploy/apis/core/pipeline_manager.py b/mmdeploy/apis/core/pipeline_manager.py index f46697a238..ab6df3cf37 100644 --- a/mmdeploy/apis/core/pipeline_manager.py +++ b/mmdeploy/apis/core/pipeline_manager.py @@ -76,8 +76,10 @@ def pop_mp_output(self, call_id: int = None) -> Any: """pop multiprocess output.""" assert self._mp_dict is not None, 'mp_dict is None.' call_id = self._call_id if call_id is None else call_id - assert call_id in self._mp_dict, \ - f'`{self._func_name}` with Call id: {call_id} failed.' + if call_id not in self._mp_dict: + get_root_logger().error( + f'`{self._func_name}` with Call id: {call_id} failed. exit.') + exit() ret = self._mp_dict[call_id] self._mp_dict.pop(call_id) return ret diff --git a/mmdeploy/apis/onnx/passes/optimize_onnx.py b/mmdeploy/apis/onnx/passes/optimize_onnx.py index 0997713a09..48b1e2933c 100644 --- a/mmdeploy/apis/onnx/passes/optimize_onnx.py +++ b/mmdeploy/apis/onnx/passes/optimize_onnx.py @@ -12,6 +12,10 @@ def optimize_onnx(graph, params_dict, torch_out): ts_optimizer.onnx._jit_pass_flatten_cls_head(graph) ts_optimizer.onnx._jit_pass_fuse_select_assign(graph, params_dict) except Exception: - pass + logger.warning( + 'Can not optimize model, please build torchscipt extension.\n' + 'More details: ' + 'https://github.com/open-mmlab/mmdeploy/blob/master/docs/en/experimental/onnx_optimizer.md' # noqa + ) return graph, params_dict, torch_out diff --git a/mmdeploy/apis/pytorch2onnx.py b/mmdeploy/apis/pytorch2onnx.py index 59647e89ae..4c1bdb58b7 100644 --- a/mmdeploy/apis/pytorch2onnx.py +++ b/mmdeploy/apis/pytorch2onnx.py @@ -82,6 +82,7 @@ def torch2onnx(img: Any, 'verbose', False) keep_initializers_as_inputs = onnx_cfg.get('keep_initializers_as_inputs', True) + optimize = onnx_cfg.get('optimize', False) with no_mp(): export( torch_model, @@ -94,4 +95,5 @@ def torch2onnx(img: Any, opset_version=opset_version, dynamic_axes=dynamic_axes, verbose=verbose, - keep_initializers_as_inputs=keep_initializers_as_inputs) + keep_initializers_as_inputs=keep_initializers_as_inputs, + optimize=optimize) From c1cc739224115c321b99e369942ad8e1914c0a32 Mon Sep 17 00:00:00 2001 From: grimoire Date: Sat, 11 Jun 2022 22:25:03 +0800 Subject: [PATCH 4/8] Add fuse select assign pass --- .../passes/onnx/fuse_select_assign.cpp | 148 ++++++++++++++++++ .../passes/onnx/fuse_select_assign.h | 17 ++ 2 files changed, 165 insertions(+) create mode 100644 csrc/backend_ops/torchscript/optimizer/passes/onnx/fuse_select_assign.cpp create mode 100644 csrc/backend_ops/torchscript/optimizer/passes/onnx/fuse_select_assign.h diff --git a/csrc/backend_ops/torchscript/optimizer/passes/onnx/fuse_select_assign.cpp b/csrc/backend_ops/torchscript/optimizer/passes/onnx/fuse_select_assign.cpp new file mode 100644 index 0000000000..2ff4d09f70 --- /dev/null +++ b/csrc/backend_ops/torchscript/optimizer/passes/onnx/fuse_select_assign.cpp @@ -0,0 +1,148 @@ +#include "fuse_select_assign.h" + +#include + +#include "../../ir/subgraph_matcher.h" +#include "torch/csrc/jit/ir/irparser.h" + +namespace mmdeploy { +namespace torch_jit { + +using c10::Symbol; +using torch::jit::Block; +using torch::jit::IValue; +using torch::jit::Node; + +bool FuseSelectAssign(Node* node, std::unordered_map& params, + std::unordered_map& vmap, SubgraphMatcher& matcher) { + auto values_map = matcher.values_map(); + + auto cmp1 = values_map[vmap["cmp_1"]]->node(); + auto cmp2 = values_map[vmap["cmp_2"]]->node(); + if (cmp1 != cmp2) { + // cmp_1 == cmp_2, cmp in (Great, Less) + if (cmp1->kind() != cmp2->kind()) return false; + if (!(cmp1->kind() == Symbol::onnx("Greater") || cmp1->kind() == Symbol::onnx("Less"))) + return false; + + // check threshold + Node* cmps[] = {cmp1, cmp2}; + float thres = 0.0f; + Node* x = nullptr; + for (int i = 0; i < 2; ++i) { + auto cmp = cmps[i]; + auto threshold = cmp->inputs()[1]->node(); + if (threshold->kind() != Symbol::onnx("Constant")) return false; + auto thres_val = threshold->t(Symbol::attr("value")); + if (i == 0) { + thres = thres_val.data_ptr()[0]; + x = cmp->inputs()[0]->node(); + } else { + float tmp_val = thres_val.data_ptr()[0]; + if (fabs(thres - tmp_val) > 1e-10) { + return false; + } + if (x != cmp->inputs()[0]->node()) { + return false; + } + } + } + } + + { + // check shape of reshape + Node* shape = values_map[vmap["reshape_1_shape"]]->node(); + auto shape_val = shape->t(Symbol::attr("value")); + if (shape_val.dim() != 1) return false; + if (shape_val.data_ptr()[0] != -1) return false; + } + + { + // check transpose + Node* trans[] = {values_map[vmap["trans_1"]]->node(), values_map[vmap["trans_2"]]->node()}; + for (auto tran : trans) { + auto tran_perm = tran->is(Symbol::attr("perm")); + if (tran_perm.size() != 2) return false; + if (tran_perm[0] != 1 || tran_perm[1] != 0) return false; + } + } + + { + // check gather indice + Node* gather_inds = values_map[vmap["gather_inds_2"]]->node(); + auto inds_val = gather_inds->t(Symbol::attr("value")); + if (inds_val.dim() != 0) return false; + if (inds_val.data_ptr()[0] != 0) return false; + } + + { + // check slice start + Node* slice = values_map[vmap["slice_2"]]->node(); + auto start_name = slice->inputs()[1]->debugName(); + auto start_val = params[start_name]; + if (start_val.dim() != 1) return false; + if (start_val.data_ptr()[0] != 0) return false; + } + + // create new node + auto graph = node->owningGraph(); + auto z = values_map[vmap["z"]]; + auto y = values_map[vmap["y"]]; + auto where_node = graph->create(Symbol::onnx("Where"), {cmp1->output(), z, y}); + where_node->insertBefore(node); + where_node->output()->copyMetadata(node->output()); + node->output()->replaceAllUsesWith(where_node->output()); + return true; +} + +void FuseSelectAssign(Block* block, std::unordered_map& params, + std::unordered_map& vmap, SubgraphMatcher& matcher) { + auto graph = block->owningGraph(); + auto it = block->nodes().begin(); + while (it != block->nodes().end()) { + auto node = *it; + ++it; + for (auto block : node->blocks()) { + FuseSelectAssign(block, params, vmap, matcher); + } + + if (matcher.matchesSubgraphFromAnchorNode(node)) { + FuseSelectAssign(node, params, vmap, matcher); + } + } +} + +void FuseSelectAssign(std::shared_ptr& graph, + std::unordered_map& params) { + std::string pattern_str = R"IR( + graph(%y, %z, %cmp_1, %cmp_2, %start, %axes): + %nz_1 = onnx::NonZero(%cmp_1) + %trans_1 = onnx::Transpose(%nz_1) + %gather_1 = onnx::GatherND(%z, %trans_1) + %reshape_1_shape = onnx::Constant() + %reshape_1 = onnx::Reshape(%gather_1, %reshape_1_shape) + %shape_2 = onnx::Shape(%y) + %expand_2 = onnx::Expand(%cmp_2, %shape_2) + %nz_2 = onnx::NonZero(%expand_2) + %trans_2 = onnx::Transpose(%nz_2) + %trans_shape_2 = onnx::Shape(%trans_2) + %gather_inds_2 = onnx::Constant() + %gather_2 = onnx::Gather(%trans_shape_2, %gather_inds_2) + %unsqueeze_2 = onnx::Unsqueeze(%gather_2) + %slice_2 = onnx::Slice(%reshape_1, %start, %unsqueeze_2, %axes) + %scatter_2 = onnx::ScatterND(%y, %trans_2, %slice_2) + return (%scatter_2) + )IR"; + + Graph pattern; + std::unordered_map vmap; + torch::jit::parseIR(pattern_str, &pattern, vmap); + + SubgraphMatcher matcher(pattern, MatchAttribute::NO_MATCH); + FuseSelectAssign(graph->block(), params, vmap, matcher); + torch::jit::EliminateDeadCode( + graph->block(), true, + torch::jit::DCESideEffectPolicy::ALLOW_DELETING_NODES_WITH_SIDE_EFFECTS); +} +} // namespace torch_jit +} // namespace mmdeploy diff --git a/csrc/backend_ops/torchscript/optimizer/passes/onnx/fuse_select_assign.h b/csrc/backend_ops/torchscript/optimizer/passes/onnx/fuse_select_assign.h new file mode 100644 index 0000000000..afa0dc56d6 --- /dev/null +++ b/csrc/backend_ops/torchscript/optimizer/passes/onnx/fuse_select_assign.h @@ -0,0 +1,17 @@ +// Copyright (c) OpenMMLab. All rights reserved. +#ifndef _FUSE_SELECT_ASSIGN_H_ +#define _FUSE_SELECT_ASSIGN_H_ + +#include +namespace mmdeploy { +namespace torch_jit { +using torch::Tensor; +using torch::jit::Graph; + +// this pass is used to fuse y[x>thres] = z[x>thres] +void FuseSelectAssign(std::shared_ptr& graph, + std::unordered_map& params); +} // namespace torch_jit +} // namespace mmdeploy + +#endif From f0b6930b0f1f8a97f865832d97edc6322d4486d9 Mon Sep 17 00:00:00 2001 From: grimoire Date: Sun, 26 Jun 2022 01:29:34 +0800 Subject: [PATCH 5/8] Add CSE for ONNX --- .../torchscript/optimizer/bind.cpp | 3 + .../onnx/common_subgraph_elimination.cpp | 138 ++++++++++++++++++ .../passes/onnx/common_subgraph_elimination.h | 20 +++ mmdeploy/apis/onnx/passes/optimize_onnx.py | 4 +- tests/test_apis/test_onnx_passes.py | 50 +++++++ 5 files changed, 214 insertions(+), 1 deletion(-) create mode 100644 csrc/mmdeploy/backend_ops/torchscript/optimizer/passes/onnx/common_subgraph_elimination.cpp create mode 100644 csrc/mmdeploy/backend_ops/torchscript/optimizer/passes/onnx/common_subgraph_elimination.h diff --git a/csrc/mmdeploy/backend_ops/torchscript/optimizer/bind.cpp b/csrc/mmdeploy/backend_ops/torchscript/optimizer/bind.cpp index 660fa58d72..3b8bb0f632 100644 --- a/csrc/mmdeploy/backend_ops/torchscript/optimizer/bind.cpp +++ b/csrc/mmdeploy/backend_ops/torchscript/optimizer/bind.cpp @@ -6,6 +6,7 @@ #include #include "optimizer.h" +#include "passes/onnx/common_subgraph_elimination.h" #include "passes/onnx/flatten_cls_head.h" #include "passes/onnx/fuse_select_assign.h" #include "passes/onnx/merge_shape_concate.h" @@ -38,6 +39,8 @@ PYBIND11_MODULE(ts_optimizer, m) { onnx_module.def("_jit_pass_flatten_cls_head", FlattenClsHead, py::arg("graph")); onnx_module.def("_jit_pass_fuse_select_assign", FuseSelectAssign, py::arg("graph"), py::arg("params")); + onnx_module.def("_jit_pass_common_subgraph_elimination", CommonSubgraphElimination, + py::arg("graph"), py::arg("params")); } } // namespace torch_jit diff --git a/csrc/mmdeploy/backend_ops/torchscript/optimizer/passes/onnx/common_subgraph_elimination.cpp b/csrc/mmdeploy/backend_ops/torchscript/optimizer/passes/onnx/common_subgraph_elimination.cpp new file mode 100644 index 0000000000..c6541e630a --- /dev/null +++ b/csrc/mmdeploy/backend_ops/torchscript/optimizer/passes/onnx/common_subgraph_elimination.cpp @@ -0,0 +1,138 @@ +// https://github.com/pytorch/pytorch/blob/v1.8.1/torch/csrc/jit/passes/common_subexpression_elimination.cpp +#include "common_subgraph_elimination.h" + +#include +#include + +namespace mmdeploy { +namespace torch_jit { + +using c10::Symbol; +using torch::jit::Block; +using torch::jit::EqualNode; +using torch::jit::HashNode; +using torch::jit::Node; +using torch::jit::Value; + +struct EqualNodeWithParams { + EqualNodeWithParams(std::unordered_map& params) : params_(params) {} + + bool operator()(const Node* lhs, const Node* rhs) const { + auto lhs_inputs = lhs->inputs(); + auto rhs_inputs = rhs->inputs(); + } + + private: + std::unordered_map& params_; +}; + +struct CommonSubexpressionEliminator { + using ParamMapType = std::unordered_map>; + CommonSubexpressionEliminator(std::shared_ptr graph, + std::unordered_map& params) + : graph_(std::move(graph)), params_(params) {} + + bool run(std::function parent_lookup_fn) { + ParamMapType param_map; + return run(graph_->block(), std::move(parent_lookup_fn), param_map); + } + + // The function implements common subexpression elimination. + // Since the nodes are visited in topological order, one pass is enough. + // returns true if CSE made changes to a graph + bool run(Block* block, std::function parent_lookup_fn, ParamMapType& param_map) { + std::unordered_set subexprs; + bool changed = false; + for (auto it = block->nodes().begin(); it != block->nodes().end(); ++it) { + auto node = *it; + + // check if inputs come from params(graph input) + auto node_inputs = node->inputs(); + for (auto input : node_inputs) { + if (input->node()->kind() == Symbol::fromQualString("prim::Param")) { + auto debug_name = input->debugName(); + + // check if input in params_ + if (params_.find(debug_name) == params_.end()) continue; + + // check if input is already visited. + if (param_map.find(debug_name) != param_map.end()) continue; + + // check if there is a param has same value with input + auto val = params_[debug_name]; + bool update_map = true; + for (auto kv : param_map) { + auto param_val = kv.second.first; + if (val.device() != param_val.device()) continue; + if (val.dtype() != param_val.dtype()) continue; + if (!val.equal(param_val)) continue; + input->replaceAllUsesWith(kv.second.second); + update_map = false; + break; + } + + // add input to param_map + if (update_map) { + param_map.emplace(debug_name, + std::make_pair(std::move(val), std::move(input))); + } + } + } + + if (!node->blocks().empty()) { + // Traverse sub-blocks. + for (auto block : node->blocks()) { + changed |= run( + block, + [&](Node* n) { + auto existing = subexprs.find(n); + if (existing != subexprs.end()) { + return *existing; + } + + return parent_lookup_fn(n); + }, + param_map); + } + + continue; + } + + // Check for CSE opportunities in the parent block. + auto parent_lookup = parent_lookup_fn(node); + auto g_out = node->owningGraph()->outputs(); + if (parent_lookup != nullptr) { + changed = true; + node->replaceAllUsesWith(parent_lookup); + it.destroyCurrent(); + continue; + } + + // Check whether the same subexpression already exists. + auto subit = subexprs.insert(node); + if (!subit.second) { + // Subexpression exists, replace the uses of node, and destroy it. + auto existing = *subit.first; + + changed = true; + node->replaceAllUsesWith(existing); + // Destroy the node. + it.destroyCurrent(); + } + } + + return changed; + } + + private: + std::shared_ptr graph_; + std::unordered_map& params_; +}; + +void CommonSubgraphElimination(std::shared_ptr& graph, + std::unordered_map& params) { + CommonSubexpressionEliminator cse(graph, params); + cse.run([](Node*) { return nullptr; }); +} +} // namespace torch_jit +} // namespace mmdeploy diff --git a/csrc/mmdeploy/backend_ops/torchscript/optimizer/passes/onnx/common_subgraph_elimination.h b/csrc/mmdeploy/backend_ops/torchscript/optimizer/passes/onnx/common_subgraph_elimination.h new file mode 100644 index 0000000000..d90b98073e --- /dev/null +++ b/csrc/mmdeploy/backend_ops/torchscript/optimizer/passes/onnx/common_subgraph_elimination.h @@ -0,0 +1,20 @@ +// Copyright (c) OpenMMLab. All rights reserved. +#ifndef _COMMON_SUBGRAPH_ELIMINATION_H_ +#define _COMMON_SUBGRAPH_ELIMINATION_H_ + +#include +namespace mmdeploy { +namespace torch_jit { +using torch::Tensor; +using torch::jit::Graph; + +// This pass is used eliminate the common subgraph. +// There are two main difference between the one in torch/csrc/jit/pass +// 1. AliasDb is not needed in ONNX model +// 2. params might also participated in the elimination +void CommonSubgraphElimination(std::shared_ptr& graph, + std::unordered_map& params); +} // namespace torch_jit +} // namespace mmdeploy + +#endif diff --git a/mmdeploy/apis/onnx/passes/optimize_onnx.py b/mmdeploy/apis/onnx/passes/optimize_onnx.py index 48b1e2933c..9afc197bdc 100644 --- a/mmdeploy/apis/onnx/passes/optimize_onnx.py +++ b/mmdeploy/apis/onnx/passes/optimize_onnx.py @@ -11,7 +11,9 @@ def optimize_onnx(graph, params_dict, torch_out): ts_optimizer.onnx._jit_pass_onnx_peephole(graph) ts_optimizer.onnx._jit_pass_flatten_cls_head(graph) ts_optimizer.onnx._jit_pass_fuse_select_assign(graph, params_dict) - except Exception: + ts_optimizer.onnx._jit_pass_common_subgraph_elimination( + graph, params_dict) + except ImportError: logger.warning( 'Can not optimize model, please build torchscipt extension.\n' 'More details: ' diff --git a/tests/test_apis/test_onnx_passes.py b/tests/test_apis/test_onnx_passes.py index cd11972877..420ea2572f 100644 --- a/tests/test_apis/test_onnx_passes.py +++ b/tests/test_apis/test_onnx_passes.py @@ -236,3 +236,53 @@ def forward(self, x): node, _ = _find_next_node(0, nodes, 'Where') assert node is not None + + +def test_common_subgraph_elimination(): + pytest.importorskip('mmdeploy.backend.torchscript.ts_optimizer.onnx') + + try: + from mmdeploy.backend.torchscript import ts_optimizer + opt_pass = ts_optimizer.onnx._jit_pass_common_subgraph_elimination + except ImportError: + pytest.skip('pass not found.') + + def _optimize_onnx(graph, params_dict, torch_out): + opt_pass(graph, params_dict) + return graph, params_dict, torch_out + + class TestModel(torch.nn.Module): + + def __init__(self) -> None: + super().__init__() + + def forward(self, x): + y = x.unsqueeze(1) + z = x.unsqueeze(1) + return y + z + + model = TestModel() + x = torch.rand(1, 2, 3) + + with RewriterContext({}, onnx_custom_passes=_optimize_onnx): + torch.onnx.export( + model, + x, + onnx_file, + input_names=['input'], + output_names=['output'], + dynamic_axes=dict(input={ + 1: 'h', + 2: 'w' + }), + opset_version=11) + + onnx_model = onnx.load(onnx_file) + graph = onnx_model.graph + nodes = graph.node + + unsqueeze_count = 0 + for n in nodes: + if n.op_type == 'Unsqueeze': + unsqueeze_count += 1 + assert unsqueeze_count == 1 From 3debca8bc8b23eab89161fddf8a8a867260a6128 Mon Sep 17 00:00:00 2001 From: grimoire Date: Fri, 1 Jul 2022 12:22:03 +0800 Subject: [PATCH 6/8] remove useless code --- .../passes/onnx/fuse_select_assign.cpp | 148 ------------------ .../passes/onnx/fuse_select_assign.h | 17 -- 2 files changed, 165 deletions(-) delete mode 100644 csrc/backend_ops/torchscript/optimizer/passes/onnx/fuse_select_assign.cpp delete mode 100644 csrc/backend_ops/torchscript/optimizer/passes/onnx/fuse_select_assign.h diff --git a/csrc/backend_ops/torchscript/optimizer/passes/onnx/fuse_select_assign.cpp b/csrc/backend_ops/torchscript/optimizer/passes/onnx/fuse_select_assign.cpp deleted file mode 100644 index 2ff4d09f70..0000000000 --- a/csrc/backend_ops/torchscript/optimizer/passes/onnx/fuse_select_assign.cpp +++ /dev/null @@ -1,148 +0,0 @@ -#include "fuse_select_assign.h" - -#include - -#include "../../ir/subgraph_matcher.h" -#include "torch/csrc/jit/ir/irparser.h" - -namespace mmdeploy { -namespace torch_jit { - -using c10::Symbol; -using torch::jit::Block; -using torch::jit::IValue; -using torch::jit::Node; - -bool FuseSelectAssign(Node* node, std::unordered_map& params, - std::unordered_map& vmap, SubgraphMatcher& matcher) { - auto values_map = matcher.values_map(); - - auto cmp1 = values_map[vmap["cmp_1"]]->node(); - auto cmp2 = values_map[vmap["cmp_2"]]->node(); - if (cmp1 != cmp2) { - // cmp_1 == cmp_2, cmp in (Great, Less) - if (cmp1->kind() != cmp2->kind()) return false; - if (!(cmp1->kind() == Symbol::onnx("Greater") || cmp1->kind() == Symbol::onnx("Less"))) - return false; - - // check threshold - Node* cmps[] = {cmp1, cmp2}; - float thres = 0.0f; - Node* x = nullptr; - for (int i = 0; i < 2; ++i) { - auto cmp = cmps[i]; - auto threshold = cmp->inputs()[1]->node(); - if (threshold->kind() != Symbol::onnx("Constant")) return false; - auto thres_val = threshold->t(Symbol::attr("value")); - if (i == 0) { - thres = thres_val.data_ptr()[0]; - x = cmp->inputs()[0]->node(); - } else { - float tmp_val = thres_val.data_ptr()[0]; - if (fabs(thres - tmp_val) > 1e-10) { - return false; - } - if (x != cmp->inputs()[0]->node()) { - return false; - } - } - } - } - - { - // check shape of reshape - Node* shape = values_map[vmap["reshape_1_shape"]]->node(); - auto shape_val = shape->t(Symbol::attr("value")); - if (shape_val.dim() != 1) return false; - if (shape_val.data_ptr()[0] != -1) return false; - } - - { - // check transpose - Node* trans[] = {values_map[vmap["trans_1"]]->node(), values_map[vmap["trans_2"]]->node()}; - for (auto tran : trans) { - auto tran_perm = tran->is(Symbol::attr("perm")); - if (tran_perm.size() != 2) return false; - if (tran_perm[0] != 1 || tran_perm[1] != 0) return false; - } - } - - { - // check gather indice - Node* gather_inds = values_map[vmap["gather_inds_2"]]->node(); - auto inds_val = gather_inds->t(Symbol::attr("value")); - if (inds_val.dim() != 0) return false; - if (inds_val.data_ptr()[0] != 0) return false; - } - - { - // check slice start - Node* slice = values_map[vmap["slice_2"]]->node(); - auto start_name = slice->inputs()[1]->debugName(); - auto start_val = params[start_name]; - if (start_val.dim() != 1) return false; - if (start_val.data_ptr()[0] != 0) return false; - } - - // create new node - auto graph = node->owningGraph(); - auto z = values_map[vmap["z"]]; - auto y = values_map[vmap["y"]]; - auto where_node = graph->create(Symbol::onnx("Where"), {cmp1->output(), z, y}); - where_node->insertBefore(node); - where_node->output()->copyMetadata(node->output()); - node->output()->replaceAllUsesWith(where_node->output()); - return true; -} - -void FuseSelectAssign(Block* block, std::unordered_map& params, - std::unordered_map& vmap, SubgraphMatcher& matcher) { - auto graph = block->owningGraph(); - auto it = block->nodes().begin(); - while (it != block->nodes().end()) { - auto node = *it; - ++it; - for (auto block : node->blocks()) { - FuseSelectAssign(block, params, vmap, matcher); - } - - if (matcher.matchesSubgraphFromAnchorNode(node)) { - FuseSelectAssign(node, params, vmap, matcher); - } - } -} - -void FuseSelectAssign(std::shared_ptr& graph, - std::unordered_map& params) { - std::string pattern_str = R"IR( - graph(%y, %z, %cmp_1, %cmp_2, %start, %axes): - %nz_1 = onnx::NonZero(%cmp_1) - %trans_1 = onnx::Transpose(%nz_1) - %gather_1 = onnx::GatherND(%z, %trans_1) - %reshape_1_shape = onnx::Constant() - %reshape_1 = onnx::Reshape(%gather_1, %reshape_1_shape) - %shape_2 = onnx::Shape(%y) - %expand_2 = onnx::Expand(%cmp_2, %shape_2) - %nz_2 = onnx::NonZero(%expand_2) - %trans_2 = onnx::Transpose(%nz_2) - %trans_shape_2 = onnx::Shape(%trans_2) - %gather_inds_2 = onnx::Constant() - %gather_2 = onnx::Gather(%trans_shape_2, %gather_inds_2) - %unsqueeze_2 = onnx::Unsqueeze(%gather_2) - %slice_2 = onnx::Slice(%reshape_1, %start, %unsqueeze_2, %axes) - %scatter_2 = onnx::ScatterND(%y, %trans_2, %slice_2) - return (%scatter_2) - )IR"; - - Graph pattern; - std::unordered_map vmap; - torch::jit::parseIR(pattern_str, &pattern, vmap); - - SubgraphMatcher matcher(pattern, MatchAttribute::NO_MATCH); - FuseSelectAssign(graph->block(), params, vmap, matcher); - torch::jit::EliminateDeadCode( - graph->block(), true, - torch::jit::DCESideEffectPolicy::ALLOW_DELETING_NODES_WITH_SIDE_EFFECTS); -} -} // namespace torch_jit -} // namespace mmdeploy diff --git a/csrc/backend_ops/torchscript/optimizer/passes/onnx/fuse_select_assign.h b/csrc/backend_ops/torchscript/optimizer/passes/onnx/fuse_select_assign.h deleted file mode 100644 index afa0dc56d6..0000000000 --- a/csrc/backend_ops/torchscript/optimizer/passes/onnx/fuse_select_assign.h +++ /dev/null @@ -1,17 +0,0 @@ -// Copyright (c) OpenMMLab. All rights reserved. -#ifndef _FUSE_SELECT_ASSIGN_H_ -#define _FUSE_SELECT_ASSIGN_H_ - -#include -namespace mmdeploy { -namespace torch_jit { -using torch::Tensor; -using torch::jit::Graph; - -// this pass is used to fuse y[x>thres] = z[x>thres] -void FuseSelectAssign(std::shared_ptr& graph, - std::unordered_map& params); -} // namespace torch_jit -} // namespace mmdeploy - -#endif From c8724491e67b30cf2996af070c1bcb485e68e6b5 Mon Sep 17 00:00:00 2001 From: grimoire Date: Fri, 1 Jul 2022 14:41:20 +0800 Subject: [PATCH 7/8] Install optimizer by setup tools --- MANIFEST.in | 3 + .../backend_ops/torchscript/CMakeLists.txt | 1 - mmdeploy/apis/onnx/optimizer.py | 2 +- mmdeploy/apis/onnx/passes/optimize_onnx.py | 3 +- setup.py | 75 ++++++++++++++++++- tests/test_apis/test_onnx_passes.py | 10 +-- tools/package_tools/mmdeploy_builder.py | 6 ++ 7 files changed, 90 insertions(+), 10 deletions(-) diff --git a/MANIFEST.in b/MANIFEST.in index 7c85a3240b..f3427de2fe 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -5,3 +5,6 @@ include mmdeploy/backend/ncnn/*.pyd include mmdeploy/lib/*.so include mmdeploy/lib/*.dll include mmdeploy/lib/*.pyd +include mmdeploy/backend/torchscript/*.so +include mmdeploy/backend/torchscript/*.dll +include mmdeploy/backend/torchscript/*.pyd diff --git a/csrc/mmdeploy/backend_ops/torchscript/CMakeLists.txt b/csrc/mmdeploy/backend_ops/torchscript/CMakeLists.txt index 8d862b9411..4b080f621a 100644 --- a/csrc/mmdeploy/backend_ops/torchscript/CMakeLists.txt +++ b/csrc/mmdeploy/backend_ops/torchscript/CMakeLists.txt @@ -1,4 +1,3 @@ # Copyright (c) OpenMMLab. All rights reserved. add_subdirectory(ops) -add_subdirectory(optimizer) diff --git a/mmdeploy/apis/onnx/optimizer.py b/mmdeploy/apis/onnx/optimizer.py index 612e9d8ea8..07b89d1eaf 100644 --- a/mmdeploy/apis/onnx/optimizer.py +++ b/mmdeploy/apis/onnx/optimizer.py @@ -15,7 +15,7 @@ def model_to_graph__custom_optimizer(ctx, *args, **kwargs): assert isinstance( custom_passes, Callable ), f'Expect a callable onnx_custom_passes, get {type(custom_passes)}.' - graph, params_dict, torch_out = custom_passes(graph, params_dict, + graph, params_dict, torch_out = custom_passes(ctx, graph, params_dict, torch_out) return graph, params_dict, torch_out diff --git a/mmdeploy/apis/onnx/passes/optimize_onnx.py b/mmdeploy/apis/onnx/passes/optimize_onnx.py index 8b12c6bf92..19e14bc292 100644 --- a/mmdeploy/apis/onnx/passes/optimize_onnx.py +++ b/mmdeploy/apis/onnx/passes/optimize_onnx.py @@ -2,7 +2,8 @@ from mmdeploy.utils import get_root_logger -def optimize_onnx(graph, params_dict, torch_out): +def optimize_onnx(ctx, graph, params_dict, torch_out): + """The optimize callback of the onnx model.""" logger = get_root_logger() logger.info('Execute onnx optimize passes.') try: diff --git a/setup.py b/setup.py index 86e5cdf022..ce175905d5 100644 --- a/setup.py +++ b/setup.py @@ -2,6 +2,12 @@ from setuptools import find_packages, setup +try: + from torch.utils.cpp_extension import BuildExtension + cmd_class = {'build_ext': BuildExtension} +except ModuleNotFoundError: + cmd_class = {} + print('Skip building ext ops due to the absence of torch.') pwd = os.path.dirname(__file__) version_file = 'mmdeploy/version.py' @@ -96,6 +102,71 @@ def gen_packages_items(): return packages +def get_extensions(): + extensions = [] + ext_name = 'mmdeploy.backend.torchscript.ts_optimizer' + import glob + import platform + + from torch.utils.cpp_extension import CppExtension + + try: + import psutil + num_cpu = len(psutil.Process().cpu_affinity()) + cpu_use = max(4, num_cpu - 1) + except (ModuleNotFoundError, AttributeError): + cpu_use = 4 + + os.environ.setdefault('MAX_JOBS', str(cpu_use)) + define_macros = [] + + # Before PyTorch1.8.0, when compiling CUDA code, `cxx` is a + # required key passed to PyTorch. Even if there is no flag passed + # to cxx, users also need to pass an empty list to PyTorch. + # Since PyTorch1.8.0, it has a default value so users do not need + # to pass an empty list anymore. + # More details at https://github.com/pytorch/pytorch/pull/45956 + extra_compile_args = {'cxx': []} + + # c++14 is required. + # However, in the windows environment, some standard libraries + # will depend on c++17 or higher. In fact, for the windows + # environment, the compiler will choose the appropriate compiler + # to compile those cpp files, so there is no need to add the + # argument + if platform.system() != 'Windows': + extra_compile_args['cxx'] = ['-std=c++14'] + + include_dirs = [] + + op_files = glob.glob( + './csrc/mmdeploy/backend_ops/torchscript/optimizer/*.cpp' + ) + glob.glob( + './csrc/mmdeploy/backend_ops/torchscript/optimizer/ir/*.cpp' + ) + glob.glob( + './csrc/mmdeploy/backend_ops/torchscript/optimizer/passes/onnx/*.cpp') + extension = CppExtension + # include_dirs.append(os.path.abspath('./mmcv/ops/csrc/common')) + + # c++14 is required. + # However, in the windows environment, some standard libraries + # will depend on c++17 or higher. In fact, for the windows + # environment, the compiler will choose the appropriate compiler + # to compile those cpp files, so there is no need to add the + # argument + if 'nvcc' in extra_compile_args and platform.system() != 'Windows': + extra_compile_args['nvcc'] += ['-std=c++14'] + + ext_ops = extension( + name=ext_name, + sources=op_files, + include_dirs=include_dirs, + define_macros=define_macros, + extra_compile_args=extra_compile_args) + extensions.append(ext_ops) + return extensions + + if __name__ == '__main__': setup( name='mmdeploy', @@ -128,6 +199,6 @@ def gen_packages_items(): 'build': parse_requirements('requirements/build.txt'), 'optional': parse_requirements('requirements/optional.txt'), }, - ext_modules=[], - cmdclass={}, + ext_modules=get_extensions(), + cmdclass=cmd_class, zip_safe=False) diff --git a/tests/test_apis/test_onnx_passes.py b/tests/test_apis/test_onnx_passes.py index 420ea2572f..a2d77b4463 100644 --- a/tests/test_apis/test_onnx_passes.py +++ b/tests/test_apis/test_onnx_passes.py @@ -30,7 +30,7 @@ def test_merge_shape_concate(): except ImportError: pytest.skip('pass not found.') - def _optimize_onnx(graph, params_dict, torch_out): + def _optimize_onnx(ctx, graph, params_dict, torch_out): opt_pass(graph) return graph, params_dict, torch_out @@ -82,7 +82,7 @@ def test_peephole(): except ImportError: pytest.skip('pass not found.') - def _optimize_onnx(graph, params_dict, torch_out): + def _optimize_onnx(ctx, graph, params_dict, torch_out): opt_pass(graph) return graph, params_dict, torch_out @@ -148,7 +148,7 @@ def test_flatten_cls_head(): except ImportError: pytest.skip('pass not found.') - def _optimize_onnx(graph, params_dict, torch_out): + def _optimize_onnx(ctx, graph, params_dict, torch_out): opt_pass(graph) return graph, params_dict, torch_out @@ -199,7 +199,7 @@ def test_fuse_select_assign(): except ImportError: pytest.skip('pass not found.') - def _optimize_onnx(graph, params_dict, torch_out): + def _optimize_onnx(ctx, graph, params_dict, torch_out): opt_pass(graph, params_dict) return graph, params_dict, torch_out @@ -247,7 +247,7 @@ def test_common_subgraph_elimination(): except ImportError: pytest.skip('pass not found.') - def _optimize_onnx(graph, params_dict, torch_out): + def _optimize_onnx(ctx, graph, params_dict, torch_out): opt_pass(graph, params_dict) return graph, params_dict, torch_out diff --git a/tools/package_tools/mmdeploy_builder.py b/tools/package_tools/mmdeploy_builder.py index 1e3125ec75..b748d8d756 100644 --- a/tools/package_tools/mmdeploy_builder.py +++ b/tools/package_tools/mmdeploy_builder.py @@ -131,6 +131,12 @@ def _remove_in_mmdeploy(path): for ncnn_ext_path in ncnn_ext_paths: os.remove(ncnn_ext_path) + # remove ts_optmizer + ts_optimizer_paths = glob( + osp.join(mmdeploy_dir, 'mmdeploy/backend/ncnn/ts_optimizer.*')) + for ts_optimizer_path in ts_optimizer_paths: + os.remove(ts_optimizer_path) + def build_mmdeploy(cfg, mmdeploy_dir, dist_dir=None): cmake_flags = cfg.get('cmake_flags', []) From 9de850c1c36050905441596c488212992055c93b Mon Sep 17 00:00:00 2001 From: grimoire Date: Mon, 25 Jul 2022 10:56:49 +0800 Subject: [PATCH 8/8] fix comment --- setup.py | 1 - tools/package_tools/mmdeploy_builder.py | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/setup.py b/setup.py index ce175905d5..634423a5ae 100644 --- a/setup.py +++ b/setup.py @@ -146,7 +146,6 @@ def get_extensions(): ) + glob.glob( './csrc/mmdeploy/backend_ops/torchscript/optimizer/passes/onnx/*.cpp') extension = CppExtension - # include_dirs.append(os.path.abspath('./mmcv/ops/csrc/common')) # c++14 is required. # However, in the windows environment, some standard libraries diff --git a/tools/package_tools/mmdeploy_builder.py b/tools/package_tools/mmdeploy_builder.py index a873982c61..bdf70ed2e3 100644 --- a/tools/package_tools/mmdeploy_builder.py +++ b/tools/package_tools/mmdeploy_builder.py @@ -135,7 +135,7 @@ def _remove_in_mmdeploy(path): # remove ts_optmizer ts_optimizer_paths = glob( - osp.join(mmdeploy_dir, 'mmdeploy/backend/ncnn/ts_optimizer.*')) + osp.join(mmdeploy_dir, 'mmdeploy/backend/torchscript/ts_optimizer.*')) for ts_optimizer_path in ts_optimizer_paths: os.remove(ts_optimizer_path)