[Feature] Update deploy test tools (#553)
* add trt test tool

* create deploy_test, update document

* fix with isort

* move import inside __init__

* remove comment, fix doc

* update document
q.yao authored May 25, 2021
1 parent 66b0525 commit dc5d53b
Showing 2 changed files with 86 additions and 17 deletions.
31 changes: 20 additions & 11 deletions docs/useful_tools.md
@@ -76,9 +76,9 @@ Description of arguments:

**Note**: This tool is still experimental. Some customized operators are not supported for now.

### Evaluate ONNX model with ONNXRuntime
### Evaluate ONNX model

We provide `tools/ort_test.py` to evaluate ONNX model with ONNXRuntime backend.
We provide `tools/deploy_test.py` to evaluate ONNX models with different backends.

#### Prerequisite

@@ -88,12 +88,15 @@ We provide `tools/ort_test.py` to evaluate ONNX model with ONNXRuntime backend.
pip install onnx onnxruntime-gpu
```

- Install TensorRT following [how-to-build-tensorrt-plugins-in-mmcv](https://mmcv.readthedocs.io/en/latest/tensorrt_plugin.html#how-to-build-tensorrt-plugins-in-mmcv) (optional)
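
Since TensorRT is optional, it can help to check which backends are importable before running the tool. The following is a minimal sketch, not part of this commit: the helper name `backend_available` is hypothetical, and the module names are taken from the imports that appear in `tools/deploy_test.py`.

```python
import importlib.util


def backend_available(module_name: str) -> bool:
    """Return True if `module_name` can be located without fully loading it.

    Hypothetical helper for illustration; it is not part of deploy_test.py.
    """
    try:
        return importlib.util.find_spec(module_name) is not None
    except (ImportError, ValueError):
        # A missing parent package (e.g. mmcv for 'mmcv.tensorrt') raises
        # ModuleNotFoundError, which is a subclass of ImportError.
        return False


if __name__ == '__main__':
    # Module names as imported by the ONNXRuntime and TensorRT wrappers.
    for name in ('onnxruntime', 'mmcv.tensorrt'):
        status = 'found' if backend_available(name) else 'missing'
        print(f'{name}: {status}')
```

A module being found only means it can be located; the TensorRT plugins may still need mmcv to be built from source, as the linked guide describes.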

#### Usage

```bash
python tools/ort_test.py \
python tools/deploy_test.py \
${CONFIG_FILE} \
${ONNX_FILE} \
${MODEL_FILE} \
    --backend ${BACKEND} \
--out ${OUTPUT_FILE} \
--eval ${EVALUATION_METRICS} \
--show \
@@ -106,7 +109,8 @@ python tools/ort_test.py \
Description of all arguments

- `config`: The path of a model config file.
- `model`: The path of a ONNX model file.
- `model`: The path of a converted model file.
- `--backend`: Backend of the inference, options: `onnxruntime`, `tensorrt`.
- `--out`: The path of output result file in pickle format.
- `--format-only` : Format the output results without performing evaluation. It is useful when you want to format the result to a specific format and submit it to the test server. If not specified, it will be set to `False`. Note that this argument is **mutually exclusive** with `--eval`.
- `--eval`: Evaluation metrics, which depend on the dataset, e.g., "mIoU" for generic datasets, and "cityscapes" for Cityscapes. Note that this argument is **mutually exclusive** with `--format-only`.
@@ -118,12 +122,17 @@ Description of all arguments

#### Results and Models

| Model | Config | Dataset | Metric | PyTorch | ONNXRuntime |
| :--------: | :--------------------------------------------: | :--------: | :----: | :-----: | :---------: |
| FCN | fcn_r50-d8_512x1024_40k_cityscapes.py | cityscapes | mIoU | 72.2 | 72.2 |
| PSPNet | pspnet_r50-d8_769x769_40k_cityscapes.py | cityscapes | mIoU | 78.2 | 78.1 |
| deeplabv3 | deeplabv3_r50-d8_769x769_40k_cityscapes.py | cityscapes | mIoU | 78.5 | 78.3 |
| deeplabv3+ | deeplabv3plus_r50-d8_769x769_40k_cityscapes.py | cityscapes | mIoU | 78.9 | 78.7 |
| Model | Config | Dataset | Metric | PyTorch | ONNXRuntime | TensorRT-fp32 | TensorRT-fp16 |
| :--------: | :---------------------------------------------: | :--------: | :----: | :-----: | :---------: | :-----------: | :-----------: |
| FCN | fcn_r50-d8_512x1024_40k_cityscapes.py | cityscapes | mIoU | 72.2 | 72.2 | 72.2 | 72.2 |
| PSPNet | pspnet_r50-d8_512x1024_40k_cityscapes.py | cityscapes | mIoU | 77.8 | 77.8 | 77.8 | 77.8 |
| deeplabv3 | deeplabv3_r50-d8_512x1024_40k_cityscapes.py | cityscapes | mIoU | 79.0 | 79.0 | 79.0 | 79.0 |
| deeplabv3+ | deeplabv3plus_r50-d8_512x1024_40k_cityscapes.py | cityscapes | mIoU | 79.6 | 79.5 | 79.5 | 79.5 |
| PSPNet | pspnet_r50-d8_769x769_40k_cityscapes.py | cityscapes | mIoU | 78.2 | 78.1 | | |
| deeplabv3 | deeplabv3_r50-d8_769x769_40k_cityscapes.py | cityscapes | mIoU | 78.5 | 78.3 | | |
| deeplabv3+ | deeplabv3plus_r50-d8_769x769_40k_cityscapes.py | cityscapes | mIoU | 78.9 | 78.7 | | |

**Note**: TensorRT is only available on configs with `whole mode`.
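
Whether a config qualifies can be read from its `model.test_cfg.mode` entry, which is the same field `TensorRTSegmentor` reads in this commit. The sketch below uses plain dictionaries instead of a loaded mmcv config to stay self-contained; the helper name `supports_tensorrt` is hypothetical, and `slide` is assumed to be the alternative test mode.

```python
def supports_tensorrt(cfg: dict) -> bool:
    """Return True when the config tests in `whole` mode.

    Hypothetical helper; `cfg` is a plain dict standing in for a loaded
    config with a `model.test_cfg.mode` entry.
    """
    mode = cfg.get('model', {}).get('test_cfg', {}).get('mode')
    return mode == 'whole'


# Illustrative configs, not copied from the repository.
whole_cfg = {'model': {'test_cfg': {'mode': 'whole'}}}
slide_cfg = {'model': {'test_cfg': {'mode': 'slide'}}}

print(supports_tensorrt(whole_cfg))
print(supports_tensorrt(slide_cfg))
```

This would be consistent with the empty TensorRT cells for the 769x769 configs in the table above, assuming those configs do not test in `whole` mode.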

### Convert to TorchScript (experimental)

72 changes: 66 additions & 6 deletions tools/ort_test.py → tools/deploy_test.py
@@ -2,10 +2,10 @@
import os
import os.path as osp
import warnings
from typing import Any, Iterable

import mmcv
import numpy as np
import onnxruntime as ort
import torch
from mmcv.parallel import MMDataParallel
from mmcv.runner import get_dist_info
@@ -18,8 +18,10 @@

class ONNXRuntimeSegmentor(BaseSegmentor):

def __init__(self, onnx_file, cfg, device_id):
def __init__(self, onnx_file: str, cfg: Any, device_id: int):
super(ONNXRuntimeSegmentor, self).__init__()
import onnxruntime as ort

# get the custom op path
ort_custom_op_path = ''
try:
@@ -60,7 +62,8 @@ def encode_decode(self, img, img_metas):
def forward_train(self, imgs, img_metas, **kwargs):
raise NotImplementedError('This method is not implemented.')

def simple_test(self, img, img_meta, **kwargs):
def simple_test(self, img: torch.Tensor, img_meta: Iterable,
**kwargs) -> list:
device_type = img.device.type
self.io_binding.bind_input(
name='input',
@@ -87,11 +90,63 @@ def aug_test(self, imgs, img_metas, **kwargs):
raise NotImplementedError('This method is not implemented.')


def parse_args():
class TensorRTSegmentor(BaseSegmentor):

    def __init__(self, trt_file: str, cfg: Any, device_id: int):
        super(TensorRTSegmentor, self).__init__()
        from mmcv.tensorrt import TRTWraper, load_tensorrt_plugin
        try:
            load_tensorrt_plugin()
        except (ImportError, ModuleNotFoundError):
            # Adjacent string literals avoid embedding the source
            # indentation in the warning message.
            warnings.warn('If input model has custom op from mmcv, you may '
                          'have to build mmcv with TensorRT from source.')
        model = TRTWraper(
            trt_file, input_names=['input'], output_names=['output'])

        self.model = model
        self.device_id = device_id
        self.cfg = cfg
        self.test_mode = cfg.model.test_cfg.mode

    def extract_feat(self, imgs):
        raise NotImplementedError('This method is not implemented.')

    def encode_decode(self, img, img_metas):
        raise NotImplementedError('This method is not implemented.')

    def forward_train(self, imgs, img_metas, **kwargs):
        raise NotImplementedError('This method is not implemented.')

    def simple_test(self, img: torch.Tensor, img_meta: Iterable,
                    **kwargs) -> list:
        with torch.cuda.device(self.device_id), torch.no_grad():
            seg_pred = self.model({'input': img})['output']
        seg_pred = seg_pred.detach().cpu().numpy()
        # whole might support dynamic reshape
        ori_shape = img_meta[0]['ori_shape']
        if not (ori_shape[0] == seg_pred.shape[-2]
                and ori_shape[1] == seg_pred.shape[-1]):
            seg_pred = torch.from_numpy(seg_pred).float()
            seg_pred = torch.nn.functional.interpolate(
                seg_pred, size=tuple(ori_shape[:2]), mode='nearest')
            seg_pred = seg_pred.long().detach().cpu().numpy()
        seg_pred = seg_pred[0]
        seg_pred = list(seg_pred)
        return seg_pred

    def aug_test(self, imgs, img_metas, **kwargs):
        raise NotImplementedError('This method is not implemented.')
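
The resize step in `simple_test` maps the predicted label map back to the original image size with nearest-neighbour sampling, so no new label values are introduced. The following NumPy-only sketch illustrates that idea under stated assumptions: `resize_nearest` is a hypothetical helper, the prediction is assumed to have shape `(N, 1, H, W)`, and the index rule used here is not guaranteed to match `torch.nn.functional.interpolate` pixel for pixel.

```python
import numpy as np


def resize_nearest(seg: np.ndarray, out_h: int, out_w: int) -> np.ndarray:
    """Nearest-neighbour resize over the last two axes of a label map.

    Hypothetical helper: each output pixel copies the source pixel at
    floor(dst_index * in_size / out_size).
    """
    in_h, in_w = seg.shape[-2:]
    rows = (np.arange(out_h) * in_h) // out_h
    cols = (np.arange(out_w) * in_w) // out_w
    return seg[..., rows[:, None], cols[None, :]]


# A 2x2 label map standing in for a backend output of shape (N, 1, H, W).
pred = np.array([[[[0, 1], [2, 3]]]])
resized = resize_nearest(pred, 4, 4)

# As in simple_test, take the first image and return a list of arrays.
result = list(resized[0])
print(result[0])
```

Because values are copied rather than averaged, the output contains only class indices that were already present in the prediction.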


def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(
description='mmseg onnxruntime backend test (and eval) a model')
description='mmseg backend test (and eval)')
parser.add_argument('config', help='test config file path')
parser.add_argument('model', help='Input model file')
parser.add_argument(
'--backend',
help='Backend of the model.',
choices=['onnxruntime', 'tensorrt'])
parser.add_argument('--out', help='output result file in pickle format')
parser.add_argument(
'--format-only',
@@ -163,7 +218,12 @@ def main():

# load onnx config and meta
cfg.model.train_cfg = None
model = ONNXRuntimeSegmentor(args.model, cfg=cfg, device_id=0)

if args.backend == 'onnxruntime':
model = ONNXRuntimeSegmentor(args.model, cfg=cfg, device_id=0)
elif args.backend == 'tensorrt':
model = TensorRTSegmentor(args.model, cfg=cfg, device_id=0)

model.CLASSES = dataset.CLASSES
model.PALETTE = dataset.PALETTE
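
In the hunk above, `--backend` is an optional flag with no default, so omitting it would leave `model` unassigned before `model.CLASSES` is set. One possible way to tighten this is sketched below: a lookup table plus `required=True`. The builder functions are hypothetical stand-ins that return strings, not the real segmentor classes.

```python
import argparse
from typing import Callable, Dict, List, Optional


def build_onnxruntime(model_path: str) -> str:
    # Stand-in for ONNXRuntimeSegmentor(model_path, cfg=cfg, device_id=0).
    return f'onnxruntime:{model_path}'


def build_tensorrt(model_path: str) -> str:
    # Stand-in for TensorRTSegmentor(model_path, cfg=cfg, device_id=0).
    return f'tensorrt:{model_path}'


# Hypothetical registry mapping each backend name to its builder.
BACKENDS: Dict[str, Callable[[str], str]] = {
    'onnxruntime': build_onnxruntime,
    'tensorrt': build_tensorrt,
}


def parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    parser = argparse.ArgumentParser(description='backend dispatch sketch')
    parser.add_argument('config', help='test config file path')
    parser.add_argument('model', help='Input model file')
    # required=True makes argparse reject a missing backend up front.
    parser.add_argument(
        '--backend',
        help='Backend of the model.',
        choices=sorted(BACKENDS),
        required=True)
    return parser.parse_args(argv)


def build_model(args: argparse.Namespace) -> str:
    return BACKENDS[args.backend](args.model)


args = parse_args(['cfg.py', 'model.trt', '--backend', 'tensorrt'])
print(build_model(args))
```

With this shape, adding another backend would only need a new registry entry, and a missing or unknown `--backend` value fails at argument parsing instead of later in `main()`.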
