Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Enhancement] Install Optimizer by setuptools #690

Merged
merged 13 commits into from
Jul 25, 2022
3 changes: 3 additions & 0 deletions MANIFEST.in
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,6 @@ include mmdeploy/backend/ncnn/*.pyd
include mmdeploy/lib/*.so
include mmdeploy/lib/*.dll
include mmdeploy/lib/*.pyd
include mmdeploy/backend/torchscript/*.so
include mmdeploy/backend/torchscript/*.dll
include mmdeploy/backend/torchscript/*.pyd
1 change: 0 additions & 1 deletion csrc/mmdeploy/backend_ops/torchscript/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
# Copyright (c) OpenMMLab. All rights reserved.

add_subdirectory(ops)
add_subdirectory(optimizer)
2 changes: 1 addition & 1 deletion mmdeploy/apis/onnx/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def model_to_graph__custom_optimizer(ctx, *args, **kwargs):
assert isinstance(
custom_passes, Callable
), f'Expect a callable onnx_custom_passes, get {type(custom_passes)}.'
graph, params_dict, torch_out = custom_passes(graph, params_dict,
graph, params_dict, torch_out = custom_passes(ctx, graph, params_dict,
torch_out)

return graph, params_dict, torch_out
Expand Down
3 changes: 2 additions & 1 deletion mmdeploy/apis/onnx/passes/optimize_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
from mmdeploy.utils import get_root_logger


def optimize_onnx(graph, params_dict, torch_out):
def optimize_onnx(ctx, graph, params_dict, torch_out):
"""The optimize callback of the onnx model."""
logger = get_root_logger()
logger.info('Execute onnx optimize passes.')
try:
Expand Down
74 changes: 72 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,12 @@

from setuptools import find_packages, setup

try:
from torch.utils.cpp_extension import BuildExtension
cmd_class = {'build_ext': BuildExtension}
except ModuleNotFoundError:
cmd_class = {}
print('Skip building ext ops due to the absence of torch.')
pwd = os.path.dirname(__file__)
version_file = 'mmdeploy/version.py'

Expand Down Expand Up @@ -96,6 +102,70 @@ def gen_packages_items():
return packages


def get_extensions():
extensions = []
ext_name = 'mmdeploy.backend.torchscript.ts_optimizer'
import glob
import platform

from torch.utils.cpp_extension import CppExtension

try:
import psutil
num_cpu = len(psutil.Process().cpu_affinity())
cpu_use = max(4, num_cpu - 1)
except (ModuleNotFoundError, AttributeError):
cpu_use = 4

os.environ.setdefault('MAX_JOBS', str(cpu_use))
define_macros = []

# Before PyTorch1.8.0, when compiling CUDA code, `cxx` is a
# required key passed to PyTorch. Even if there is no flag passed
# to cxx, users also need to pass an empty list to PyTorch.
# Since PyTorch1.8.0, it has a default value so users do not need
# to pass an empty list anymore.
# More details at https://github.com/pytorch/pytorch/pull/45956
extra_compile_args = {'cxx': []}

# c++14 is required.
# However, in the windows environment, some standard libraries
# will depend on c++17 or higher. In fact, for the windows
# environment, the compiler will choose the appropriate compiler
# to compile those cpp files, so there is no need to add the
# argument
if platform.system() != 'Windows':
extra_compile_args['cxx'] = ['-std=c++14']

include_dirs = []

op_files = glob.glob(
'./csrc/mmdeploy/backend_ops/torchscript/optimizer/*.cpp'
) + glob.glob(
'./csrc/mmdeploy/backend_ops/torchscript/optimizer/ir/*.cpp'
) + glob.glob(
'./csrc/mmdeploy/backend_ops/torchscript/optimizer/passes/onnx/*.cpp')
extension = CppExtension

# c++14 is required.
# However, in the windows environment, some standard libraries
# will depend on c++17 or higher. In fact, for the windows
# environment, the compiler will choose the appropriate compiler
# to compile those cpp files, so there is no need to add the
# argument
if 'nvcc' in extra_compile_args and platform.system() != 'Windows':
extra_compile_args['nvcc'] += ['-std=c++14']

ext_ops = extension(
name=ext_name,
sources=op_files,
include_dirs=include_dirs,
define_macros=define_macros,
extra_compile_args=extra_compile_args)
extensions.append(ext_ops)
return extensions


if __name__ == '__main__':
setup(
name='mmdeploy',
Expand Down Expand Up @@ -128,6 +198,6 @@ def gen_packages_items():
'build': parse_requirements('requirements/build.txt'),
'optional': parse_requirements('requirements/optional.txt'),
},
ext_modules=[],
cmdclass={},
ext_modules=get_extensions(),
cmdclass=cmd_class,
zip_safe=False)
10 changes: 5 additions & 5 deletions tests/test_apis/test_onnx_passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def test_merge_shape_concate():
except ImportError:
pytest.skip('pass not found.')

def _optimize_onnx(graph, params_dict, torch_out):
def _optimize_onnx(ctx, graph, params_dict, torch_out):
opt_pass(graph)
return graph, params_dict, torch_out

Expand Down Expand Up @@ -82,7 +82,7 @@ def test_peephole():
except ImportError:
pytest.skip('pass not found.')

def _optimize_onnx(graph, params_dict, torch_out):
def _optimize_onnx(ctx, graph, params_dict, torch_out):
opt_pass(graph)
return graph, params_dict, torch_out

Expand Down Expand Up @@ -148,7 +148,7 @@ def test_flatten_cls_head():
except ImportError:
pytest.skip('pass not found.')

def _optimize_onnx(graph, params_dict, torch_out):
def _optimize_onnx(ctx, graph, params_dict, torch_out):
opt_pass(graph)
return graph, params_dict, torch_out

Expand Down Expand Up @@ -199,7 +199,7 @@ def test_fuse_select_assign():
except ImportError:
pytest.skip('pass not found.')

def _optimize_onnx(graph, params_dict, torch_out):
def _optimize_onnx(ctx, graph, params_dict, torch_out):
opt_pass(graph, params_dict)
return graph, params_dict, torch_out

Expand Down Expand Up @@ -247,7 +247,7 @@ def test_common_subgraph_elimination():
except ImportError:
pytest.skip('pass not found.')

def _optimize_onnx(graph, params_dict, torch_out):
def _optimize_onnx(ctx, graph, params_dict, torch_out):
opt_pass(graph, params_dict)
return graph, params_dict, torch_out

Expand Down
6 changes: 6 additions & 0 deletions tools/package_tools/mmdeploy_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,12 @@ def _remove_in_mmdeploy(path):
for ncnn_ext_path in ncnn_ext_paths:
os.remove(ncnn_ext_path)

# remove ts_optmizer
ts_optimizer_paths = glob(
osp.join(mmdeploy_dir, 'mmdeploy/backend/torchscript/ts_optimizer.*'))
for ts_optimizer_path in ts_optimizer_paths:
os.remove(ts_optimizer_path)


def build_mmdeploy(cfg, mmdeploy_dir, dist_dir=None):
cmake_flags = cfg.get('cmake_flags', [])
Expand Down