Skip to content

Commit

Permalink
Support setup on environment with no PyTorch (#843)
Browse files Browse the repository at this point in the history
  • Loading branch information
q.yao authored Aug 1, 2022
1 parent 2968a18 commit 1a3841f
Showing 1 changed file with 64 additions and 58 deletions.
122 changes: 64 additions & 58 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@

from setuptools import find_packages, setup

EXT_TYPE = ''
try:
from torch.utils.cpp_extension import BuildExtension
cmd_class = {'build_ext': BuildExtension}
EXT_TYPE = 'torch'
except ModuleNotFoundError:
cmd_class = {}
print('Skip building ext ops due to the absence of torch.')
Expand Down Expand Up @@ -105,64 +107,68 @@ def gen_packages_items():
def get_extensions():
extensions = []
ext_name = 'mmdeploy.backend.torchscript.ts_optimizer'
import glob
import platform

from torch.utils.cpp_extension import CppExtension

try:
import psutil
num_cpu = len(psutil.Process().cpu_affinity())
cpu_use = max(4, num_cpu - 1)
except (ModuleNotFoundError, AttributeError):
cpu_use = 4

os.environ.setdefault('MAX_JOBS', str(cpu_use))
define_macros = []

# Before PyTorch1.8.0, when compiling CUDA code, `cxx` is a
# required key passed to PyTorch. Even if there is no flag passed
# to cxx, users also need to pass an empty list to PyTorch.
# Since PyTorch1.8.0, it has a default value so users do not need
# to pass an empty list anymore.
# More details at https://github.com/pytorch/pytorch/pull/45956
extra_compile_args = {'cxx': []}

# c++14 is required.
# However, in the windows environment, some standard libraries
# will depend on c++17 or higher. In fact, for the windows
# environment, the compiler will choose the appropriate compiler
# to compile those cpp files, so there is no need to add the
# argument
if platform.system() != 'Windows':
extra_compile_args['cxx'] = ['-std=c++14']

include_dirs = []

op_files = glob.glob(
'./csrc/mmdeploy/backend_ops/torchscript/optimizer/*.cpp'
) + glob.glob(
'./csrc/mmdeploy/backend_ops/torchscript/optimizer/ir/*.cpp'
) + glob.glob(
'./csrc/mmdeploy/backend_ops/torchscript/optimizer/passes/onnx/*.cpp')
extension = CppExtension

# c++14 is required.
# However, in the windows environment, some standard libraries
# will depend on c++17 or higher. In fact, for the windows
# environment, the compiler will choose the appropriate compiler
# to compile those cpp files, so there is no need to add the
# argument
if 'nvcc' in extra_compile_args and platform.system() != 'Windows':
extra_compile_args['nvcc'] += ['-std=c++14']

ext_ops = extension(
name=ext_name,
sources=op_files,
include_dirs=include_dirs,
define_macros=define_macros,
extra_compile_args=extra_compile_args)
extensions.append(ext_ops)

if EXT_TYPE == 'torch':
import glob
import platform

from torch.utils.cpp_extension import CppExtension

try:
import psutil
num_cpu = len(psutil.Process().cpu_affinity())
cpu_use = max(4, num_cpu - 1)
except (ModuleNotFoundError, AttributeError):
cpu_use = 4

os.environ.setdefault('MAX_JOBS', str(cpu_use))
define_macros = []

# Before PyTorch1.8.0, when compiling CUDA code, `cxx` is a
# required key passed to PyTorch. Even if there is no flag passed
# to cxx, users also need to pass an empty list to PyTorch.
# Since PyTorch1.8.0, it has a default value so users do not need
# to pass an empty list anymore.
# More details at https://github.com/pytorch/pytorch/pull/45956
extra_compile_args = {'cxx': []}

# c++14 is required.
# However, in the windows environment, some standard libraries
# will depend on c++17 or higher. In fact, for the windows
# environment, the compiler will choose the appropriate compiler
# to compile those cpp files, so there is no need to add the
# argument
if platform.system() != 'Windows':
extra_compile_args['cxx'] = ['-std=c++14']

include_dirs = []

op_files = glob.glob(
'./csrc/mmdeploy/backend_ops/torchscript/optimizer/*.cpp'
) + glob.glob(
'./csrc/mmdeploy/backend_ops/torchscript/optimizer/ir/*.cpp'
) + glob.glob(
'./csrc/mmdeploy/backend_ops/torchscript/optimizer/passes'
'/onnx/*.cpp')
extension = CppExtension

# c++14 is required.
# However, in the windows environment, some standard libraries
# will depend on c++17 or higher. In fact, for the windows
# environment, the compiler will choose the appropriate compiler
# to compile those cpp files, so there is no need to add the
# argument
if 'nvcc' in extra_compile_args and platform.system() != 'Windows':
extra_compile_args['nvcc'] += ['-std=c++14']

ext_ops = extension(
name=ext_name,
sources=op_files,
include_dirs=include_dirs,
define_macros=define_macros,
extra_compile_args=extra_compile_args)
extensions.append(ext_ops)

return extensions


Expand Down

0 comments on commit 1a3841f

Please sign in to comment.