[Refactor] Refactor backend API #1869

Open
wants to merge 35 commits into base: dev-1.x
Changes from 1 commit
add torchscript ir
grimoire committed Feb 24, 2023
commit 460b6cc8eee368ec34d4199a88431268f182872f
89 changes: 21 additions & 68 deletions mmdeploy/apis/torch_jit/trace.py
@@ -1,12 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved.
from copy import deepcopy
from functools import partial
from typing import Any, Dict, Optional, Sequence, Tuple, Union
from typing import Dict, Optional, Tuple, Union

import torch

from mmdeploy.core import RewriterContext, patch_model
from mmdeploy.utils import IR, Backend, get_ir_config, get_root_logger
from mmdeploy.ir.torchscript import export
from mmdeploy.utils import Backend
from ..core import PIPELINE_MANAGER


@@ -27,9 +25,10 @@ def trace(func: torch.nn.Module,
>>> func = create_model()
>>> inputs = get_input_tensor()
>>>
>>> jit_model = trace(
>>> trace(
>>> func,
>>> inputs,
>>> output_prefix,
>>> backend='torchscript',
>>> check_trace=False)
>>>
@@ -55,69 +54,23 @@ def trace(func: torch.nn.Module,
Returns:
torch.jit.TracedModule: The traced torch jit model.
"""
logger = get_root_logger()
logger.info('Export PyTorch model to torchscript.')
if output_path_prefix is None:
from tempfile import NamedTemporaryFile
output_path = NamedTemporaryFile(suffix='.pth').name
else:
output_path = output_path_prefix + '.pth'

def _add_or_update(cfg: dict, key: str, val: Any):
if key in cfg and isinstance(cfg[key], dict) and isinstance(val, dict):
cfg[key].update(val)
else:
cfg[key] = val

context_info = deepcopy(context_info)
deploy_cfg = context_info.pop('deploy_cfg', dict())
ir_config = dict(type='torchscript')
_add_or_update(deploy_cfg, 'ir_config', ir_config)

if isinstance(backend, Backend):
backend = backend.value
backend_config = dict(type=backend)
_add_or_update(deploy_cfg, 'backend_config', backend_config)

context_info['cfg'] = deploy_cfg
if 'backend' not in context_info:
context_info['backend'] = backend
elif context_info['backend'] != backend:
logger.warning(
f'Find backend {context_info["backend"]} in context_info.'
f' Expect {backend}.')
if 'ir' not in context_info:
context_info['ir'] = IR.TORCHSCRIPT
elif context_info['ir'] != backend:
logger.warning(f'Find ir {context_info["ir"]} in context_info.'
f' Expect {IR.TORCHSCRIPT}.')

# patch model
if isinstance(func, torch.nn.Module):
ir = IR.get(get_ir_config(deploy_cfg)['type'])
func = patch_model(func, cfg=deploy_cfg, backend=backend, ir=ir)

with RewriterContext(**context_info), torch.no_grad():

# patch input_metas
if input_metas is not None:
assert isinstance(
input_metas, dict
), f'Expect input_metas type is dict, get {type(input_metas)}.'
model_forward = func.forward
func.forward = partial(func.forward, **input_metas)

# for exporting models with weight that depends on inputs
func(*inputs) if isinstance(inputs, Sequence) \
else func(inputs)
ts_model = torch.jit.trace(
func,
inputs,
check_trace=check_trace,
check_tolerance=check_tolerance)

if input_metas is not None:
func.forward = model_forward

# save model
if output_path_prefix is not None:
output_path = output_path_prefix + '.pt'
logger.info(f'Save PyTorch model: {output_path}.')
torch.jit.save(ts_model, output_path)
export(
func,
inputs,
output_path,
backend=backend,
rewrite_context=deploy_cfg,
check_trace=check_trace,
check_tolerance=check_tolerance,
const_args=input_metas)

ts_model = torch.jit.load(output_path)

return ts_model
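After this commit, `mmdeploy.apis.torch_jit.trace` no longer builds the RewriterContext and calls `torch.jit.trace` itself; it delegates tracing and saving to the new `mmdeploy.ir.torchscript.export` and then reloads the saved module. A minimal usage sketch of the refactored wrapper, assuming `trace` stays re-exported from `mmdeploy.apis.torch_jit` and keeps the keyword names visible in this diff (`output_path_prefix`, `backend`, `check_trace`):

import torch
from mmdeploy.apis.torch_jit import trace

# toy model and example input, used only for illustration
model = torch.nn.Conv2d(3, 8, 3)
dummy_input = torch.rand(1, 3, 32, 32)

# traces the model, saves the module through the new IR export,
# then reloads and returns the torch.jit.TracedModule
jit_model = trace(
    model,
    dummy_input,
    output_path_prefix='end2end',
    backend='torchscript',
    check_trace=False)

with torch.no_grad():
    out = jit_model(dummy_input)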
8 changes: 8 additions & 0 deletions mmdeploy/ir/torchscript/__init__.py
@@ -0,0 +1,8 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .ir_manager import TorchScriptIRParam, TorchScriptManager

export = TorchScriptManager.export
export_from_param = TorchScriptManager.export_from_param
is_available = TorchScriptManager.is_available

__all__ = ['TorchScriptManager', 'TorchScriptIRParam']
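The new subpackage exposes the manager's classmethods as module-level functions. A short sketch of how the aliases could be called, assuming the `export` signature defined in `ir_manager.py` below (`model`, `args`, `output_path`, plus keyword options):

import torch
from mmdeploy.ir.torchscript import export, is_available

if is_available():  # True when `torch` can be imported
    model = torch.nn.Linear(4, 2)
    dummy_input = torch.rand(1, 4)
    # trace the model and save the TorchScript module to 'linear.pt'
    export(
        model,
        dummy_input,
        'linear.pt',
        backend='torchscript',
        check_trace=False)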
144 changes: 144 additions & 0 deletions mmdeploy/ir/torchscript/ir_manager.py
@@ -0,0 +1,144 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
from dataclasses import dataclass
from typing import Any, Dict, Optional, Union

from mmdeploy.utils.constants import Backend
from ..base import IR_MANAGERS, BaseIRManager, BaseIRParam


@dataclass
class TorchScriptIRParam(BaseIRParam):
"""TorchScript IR params.

Args:
args (Any): The arguments of the model.
work_dir (str): The working directory to save the output.
file_name (str): The file name of the output; the postfix can be omitted.
input_names (List[str]): The names to assign to the input of the ir.
output_names (List[str]): The names to assign to the output of the ir.
dynamic_axes (Dict): Determines the dynamic axes of the inputs. If not
given, all axes will be static.
backend (str): The expected backend of the ir.
rewrite_context (Dict): Provide information to the rewriter.
const_args (Any): The constant args of the model.
check_trace (bool): Whether to check the outputs after tracing.
check_tolerance (float): The tolerance used when checking the outputs.
"""

# latent fields
_default_postfix = '.pth'

# class fields
const_args: Any = None
check_trace: bool = True
check_tolerance: float = 1e-05


@IR_MANAGERS.register('torchscript', params=TorchScriptIRParam)
class TorchScriptManager(BaseIRManager):
"""TorchScript IR Manager."""

@classmethod
def export(cls,
model: Any,
args: Any,
output_path: str,
backend: Union[Backend, str] = 'default',
rewrite_context: Optional[Dict] = None,
check_trace: bool = True,
check_tolerance: float = 1e-05,
const_args: Optional[Dict] = None):
"""A wrapper of `torch.jit.trace` with some enhancement.

Examples:
>>> from mmdeploy.ir.torchscript import export
>>>
>>> func = create_model()
>>> inputs = get_input_tensor()
>>>
>>> export(
>>> func,
>>> inputs,
>>> output_path,
>>> backend='torchscript',
>>> check_trace=False)
>>>

Args:
model (torch.nn.Module): A Python function or `torch.nn.Module`
that will be run with `args`.
args (torch.Tensor, Tuple): A tuple of example inputs that
will be passed to the function while tracing.
output_path (str): The output path.
backend (Backend|str): The backend the graph will be used for.
Different backends generate different graphs.
const_args (Dict): The constant inputs of the model.
rewrite_context (Dict): The information that would be used in
the context of exporting.
check_trace (bool): Check if the same inputs run through traced
code produce the same outputs.
check_tolerance (float): Floating-point comparison tolerance to
use in the checker procedure.

Returns:
torch.jit.TracedModule: The traced torch jit model.
"""
from .trace import trace
return trace(
model,
args,
output_path,
backend=backend,
rewrite_context=rewrite_context,
check_trace=check_trace,
check_tolerance=check_tolerance,
const_args=const_args)

@classmethod
def export_from_param(cls, model: Any, param: TorchScriptIRParam):
"""Export model to given ir.

Examples:
>>> from mmdeploy.ir.torchscript import export_from_param
>>>
>>> model = create_model()
>>> param = TorchScriptIRParam(...)
>>>
>>> export_from_param(model, param)

Args:
model (Any): The model to be exported.
param (TorchScriptIRParam): The packed export parameter.
"""

from mmdeploy.utils import get_root_logger
logger = get_root_logger()

# check param validation
param.check()

# get output path
work_dir = param.work_dir
if not isinstance(work_dir, str):
logger.warning('Invalid work_dir. Use `./work_dir` as default.')
work_dir = './work_dir'

assert isinstance(param.file_name, str), ('Expect string file name, '
f'got {type(param.file_name)}')
output_path = osp.join(work_dir, param.file_name)

cls.export(
model,
param.args,
output_path,
backend=param.backend,
rewrite_context=param.rewrite_context,
check_trace=param.check_trace,
check_tolerance=param.check_tolerance,
const_args=param.const_args)

@classmethod
def is_available(cls) -> bool:
"""check if the export is available."""
import importlib.util
return importlib.util.find_spec('torch') is not None
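`export_from_param` drives the same export through a packed dataclass. A minimal sketch of the parameter-driven path, assuming the base-class fields listed in the `TorchScriptIRParam` docstring (`args`, `work_dir`, `file_name`, `backend`) exist on `BaseIRParam` and are keyword-constructible:

import torch
from mmdeploy.ir.torchscript import TorchScriptIRParam, export_from_param

model = torch.nn.Linear(4, 2)

# field names follow the TorchScriptIRParam docstring; which of them are
# required is an assumption here, not confirmed by this diff
param = TorchScriptIRParam(
    args=torch.rand(1, 4),
    work_dir='./work_dir',
    file_name='linear',  # postfix can be omitted
    backend='torchscript',
    check_trace=False)

# validates the param, joins work_dir/file_name and forwards to export()
export_from_param(model, param)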
111 changes: 111 additions & 0 deletions mmdeploy/ir/torchscript/trace.py
@@ -0,0 +1,111 @@
# Copyright (c) OpenMMLab. All rights reserved.
from copy import deepcopy
from functools import partial
from typing import Any, Dict, Optional, Sequence, Tuple, Union

import torch

from mmdeploy.core import RewriterContext, patch_model
from mmdeploy.utils import IR, Backend, get_ir_config, get_root_logger


def trace(func: torch.nn.Module,
inputs: Union[torch.Tensor, Tuple],
output_path: Optional[str] = None,
backend: Union[Backend, str] = 'default',
rewrite_context: Optional[Dict] = None,
check_trace: bool = True,
check_tolerance: float = 1e-05,
const_args: Optional[Dict] = None) -> torch.jit.TracedModule:
"""A wrapper of `torch.jit.trace` with some enhancement.

Examples:
>>> from mmdeploy.ir.torchscript.trace import trace
>>>
>>> func = create_model()
>>> inputs = get_input_tensor()
>>>
>>> jit_model = trace(
>>> func,
>>> inputs,
>>> output_path,
>>> backend='torchscript',
>>> check_trace=False)
>>>

Args:
func (torch.nn.Module): A Python function or `torch.nn.Module` that
will be run with `example_inputs`.
inputs (torch.Tensor, Tuple): A tuple of example inputs that will be
passed to the function while tracing.
output_path (str): The output path.
backend (Backend|str): The backend the graph will be used for. Different
backends generate different graphs.
const_args (Dict): The constant inputs of the model.
rewrite_context (Dict): The information that would be used in the
context of exporting.
check_trace (bool): Check if the same inputs run through traced code
produce the same outputs.
check_tolerance (float): Floating-point comparison tolerance to use in
the checker procedure.

Returns:
torch.jit.TracedModule: The traced torch jit model.
"""
logger = get_root_logger()
logger.info('Export PyTorch model to torchscript.')

def _add_or_update(cfg: dict, key: str, val: Any):
if key in cfg and isinstance(cfg[key], dict) and isinstance(val, dict):
cfg[key].update(val)
else:
cfg[key] = val

if rewrite_context is None:
rewrite_context = dict()

rewrite_context = deepcopy(rewrite_context)
ir_config = dict(type='torchscript')
_add_or_update(rewrite_context, 'ir_config', ir_config)

if isinstance(backend, Backend):
backend = backend.value
elif backend is None:
backend = 'default'
backend_config = dict(type=backend)
_add_or_update(rewrite_context, 'backend_config', backend_config)

# patch model
if isinstance(func, torch.nn.Module):
ir = IR.get(get_ir_config(rewrite_context)['type'])
func = patch_model(func, cfg=rewrite_context, backend=backend, ir=ir)

with RewriterContext(
rewrite_context, ir=IR.TORCHSCRIPT,
backend=backend), torch.no_grad():

# patch const_args
if const_args is not None:
assert isinstance(
const_args, dict
), f'Expect const_args type is dict, get {type(const_args)}.'
model_forward = func.forward
func.forward = partial(func.forward, **const_args)

# for exporting models whose weights depend on the inputs
func(*inputs) if isinstance(inputs, Sequence) \
else func(inputs)
ts_model = torch.jit.trace(
func,
inputs,
check_trace=check_trace,
check_tolerance=check_tolerance)

if const_args is not None:
func.forward = model_forward

# save model
logger.info(f'Save PyTorch model: {output_path}.')
torch.jit.save(ts_model, output_path)

return ts_model
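The low-level `trace` helper binds `const_args` onto `forward` with `functools.partial` before calling `torch.jit.trace`, so constant keyword inputs are baked into the trace instead of becoming traced tensor inputs. A hedged sketch, where the `scale` keyword is purely illustrative:

import torch
from mmdeploy.ir.torchscript.trace import trace


class ToyModel(torch.nn.Module):

    def forward(self, x, scale=1.0):
        return x * scale


model = ToyModel()
dummy_input = torch.rand(1, 3)

# 'scale' is fixed as a constant kwarg during tracing; the traced module
# is also saved to 'toy.pt'
ts_model = trace(
    model,
    dummy_input,
    output_path='toy.pt',
    backend='torchscript',
    const_args=dict(scale=2.0),
    check_trace=False)

print(ts_model(dummy_input))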