Add partition doc and sample code (#599)
* update torch2onnx tool to support onnx partition

* add model partition of yolov3

* add cn doc

* update torch2onnx tool to support onnx partition

* add model partition of yolov3

* add cn doc

* add to index.rst

* resolve comment

* resolve comments

* fix lint

* change caption level in docs
RunningLeon authored Jun 28, 2022
1 parent dc5f9c3 commit f568fe7
Showing 15 changed files with 298 additions and 85 deletions.
12 changes: 12 additions & 0 deletions configs/mmdet/detection/yolov3_partition_onnxruntime_static.py
@@ -0,0 +1,12 @@
_base_ = ['./detection_onnxruntime_static.py']

onnx_config = dict(input_shape=[608, 608])
partition_config = dict(
type='yolov3_partition',
apply_marks=True,
partition_cfg=[
dict(
save_file='yolov3.onnx',
start=['detector_forward:input'],
end=['yolo_head:input'])
])
8 changes: 4 additions & 4 deletions docs/en/01-how-to-build/build_from_docker.md
@@ -1,8 +1,8 @@
## Use Docker Image
# Use Docker Image

We provide two dockerfiles, for CPU and GPU respectively. For CPU users, we install MMDeploy with the ONNXRuntime, ncnn and OpenVINO backends. For GPU users, we install MMDeploy with the TensorRT backend. In addition, users can choose which MMDeploy version to install when building the docker image.

### Build docker image
## Build docker image

For CPU users, we can build the docker image with the latest MMDeploy through:

@@ -37,15 +37,15 @@ cd mmdeploy
docker build docker/CPU/ -t mmdeploy:inside --build-arg USE_SRC_INSIDE=true
```

### Run docker container
## Run docker container

After the docker image is built successfully, we can use `docker run` to launch the docker service. Take the GPU docker image for example:

```
docker run --gpus all -it -p 8080:8081 mmdeploy:master-gpu
```

### FAQs
## FAQs

1. CUDA error: the provided PTX was compiled with an unsupported toolchain:

38 changes: 15 additions & 23 deletions docs/en/02-how-to-run/write_config.md
@@ -1,4 +1,4 @@
## How to write config
# How to write config

This tutorial describes how to write a config for model conversion and deployment. A deployment config includes `onnx config`, `codebase config`, `backend config`.

@@ -24,11 +24,11 @@ This tutorial describes how to write a config for model conversion and deployment

<!-- TOC -->

### 1. How to write onnx config
## 1. How to write onnx config

The onnx config describes how to export a model from PyTorch to ONNX.

#### Description of onnx config arguments
### Description of onnx config arguments

- `type`: Type of config dict. Default is `onnx`.
- `export_params`: If specified, all parameters will be exported. Set this to False if you want to export an untrained model.
@@ -39,7 +39,7 @@ Onnx config to describe how to export a model from pytorch to onnx.
- `output_names`: Names to assign to the output nodes of the graph.
- `input_shape`: The height and width of input tensor to the model.

##### Example
### Example

```python
onnx_config = dict(
@@ -53,13 +53,13 @@ onnx_config = dict(
input_shape=None)
```

#### If you need to use dynamic axes
### If you need to use dynamic axes

If dynamic shapes of inputs and outputs are required, you need to add a dynamic_axes dict to the onnx config.

- `dynamic_axes`: Describes the dimensional information of inputs and outputs.

##### Example
#### Example

```python
dynamic_axes={
@@ -79,28 +79,28 @@ If the dynamic shape of inputs and outputs is required, you need to add dynamic_axes
}
```

### 2. How to write codebase config
## 2. How to write codebase config

The codebase config part contains information such as the codebase type and task type.

#### Description of codebase config arguments
### Description of codebase config arguments

- `type`: Model's codebase, including `mmcls`, `mmdet`, `mmseg`, `mmocr`, `mmedit`.
- `task`: Model's task type, referring to [List of tasks in all codebases](#list-of-tasks-in-all-codebases).

##### Example
#### Example

```python
codebase_config = dict(type='mmcls', task='Classification')
```

### 3. How to write backend config
## 3. How to write backend config

The backend config is mainly used to specify the backend on which the model runs and to provide the information needed when the model runs on the backend, referring to [ONNX Runtime](../05-supported-backends/onnxruntime.md), [TensorRT](../05-supported-backends/tensorrt.md), [ncnn](../05-supported-backends/ncnn.md), [PPLNN](../05-supported-backends/pplnn.md).

- `type`: Model's backend, including `onnxruntime`, `ncnn`, `pplnn`, `tensorrt`, `openvino`.

#### Example
### Example

```python
backend_config = dict(
@@ -117,7 +117,7 @@ backend_config = dict(
])
```

### 4. A complete example of mmcls on TensorRT
## 4. A complete example of mmcls on TensorRT

Here we provide a complete deployment config from mmcls on TensorRT.

@@ -159,7 +159,7 @@ onnx_config = dict(
input_shape=[224, 224])
```

### 5. The name rules of our deployment config
## 5. The name rules of our deployment config

There is a specific naming convention for the filename of deployment config files.

@@ -171,20 +171,12 @@ There is a specific naming convention for the filename of deployment config files.
- `backend name`: Backend's name. Note that if you use the quantization function, you need to indicate the quantization type, e.g. `tensorrt-int8`.
- `dynamic or static`: Dynamic or static export. Note that if the backend needs explicit shape information, you need to add a description of the input size in `height x width` format, e.g. `dynamic-512x1024-2048x2048`, which means the min input shape is `512x1024` and the max input shape is `2048x2048`.

#### Example
### Example

```bash
detection_tensorrt-int8_dynamic-320x320-1344x1344.py
```

### 6. How to write model config
## 6. How to write model config

Write the model config file according to the model's codebase. The model config file is used to initialize the model; refer to [MMClassification](https://github.com/open-mmlab/mmclassification/blob/master/docs/tutorials/config.md), [MMDetection](https://github.com/open-mmlab/mmdetection/blob/master/docs_zh-CN/tutorials/config.md), [MMSegmentation](https://github.com/open-mmlab/mmsegmentation/blob/master/docs_zh-CN/tutorials/config.md), [MMOCR](https://github.com/open-mmlab/mmocr/tree/main/configs), [MMEditing](https://github.com/open-mmlab/mmediting/blob/master/docs_zh-CN/config.md).

### 7. Reminder

None

### 8. FAQs

None
22 changes: 11 additions & 11 deletions docs/en/06-developer-guide/add_test_units_for_backend_ops.md
@@ -1,16 +1,16 @@
## How to add test units for backend ops
# How to add test units for backend ops

This tutorial introduces how to add unit tests for backend ops. When you add a custom op under `backend_ops`, you need to add the corresponding test unit. Test units of ops are included in `tests/test_ops/test_ops.py`.

### Prerequisite
## Prerequisite

- `Compile new ops`: After adding a new custom op, you need to recompile the relevant backend, referring to [build.md](../01-how-to-build/build_from_source.md).

### 1. Add the test program test_XXXX()
## 1. Add the test program test_XXXX()

You can put unit tests for ops in `tests/test_ops/`. Usually, the following program template can be used for your custom op.

#### example of ops unit test
### example of ops unit test

```python
@pytest.mark.parametrize('backend', [TEST_TENSORRT, TEST_ONNXRT]) # 1.1 backend test class
@@ -49,26 +49,26 @@ def test_roi_align(backend,
save_dir=save_dir)
```

#### 1.1 backend test class
### 1.1 backend test class

We provide some functions and classes for different backends, such as `TestOnnxRTExporter`, `TestTensorRTExporter`, `TestNCNNExporter`.

#### 1.2 set parameters of op
### 1.2 set parameters of op

Set some parameters of the op, such as `pool_h`, `pool_w`, `spatial_scale`, and `sampling_ratio` in `roi_align`. You can set multiple sets of parameters to test the op.

#### 1.3 op input data initialization
### 1.3 op input data initialization

Initialize the required input data.

#### 1.4 initialize op model to be tested
### 1.4 initialize op model to be tested

The model containing the custom op usually takes one of two forms.

- `torch model`: Torch model with custom operators. Python code related to the op is required; refer to the `roi_align` unit test.
- `onnx model`: Onnx model with custom operators. It needs to be built by calling the onnx API; refer to the `multi_level_roi_align` unit test.

#### 1.5 call the backend test class interface
### 1.5 call the backend test class interface

Call the backend test class's `run_and_validate` method to run the op on the backend and verify its output.

@@ -86,7 +86,7 @@ Call the backend test class `run_and_validate` to run and verify the result output
save_dir=None):
```

##### Parameter Description
#### Parameter Description

- `model`: Input model to be tested; it can be a torch model or any other backend model.
- `input_list`: List of test data, which is mapped to the order of input_names.
@@ -99,7 +99,7 @@ Call the backend test class `run_and_validate` to run and verify the result output
- `expected_result`: Expected ground truth values for verification.
- `save_dir`: The folder used to save the output files.

### 2. Test Methods
## 2. Test Methods

Use pytest to call the test function to test ops.

89 changes: 89 additions & 0 deletions docs/en/06-developer-guide/partition_model.md
@@ -0,0 +1,89 @@
# How to get partitioned ONNX models

MMDeploy supports exporting PyTorch models to partitioned onnx models. With this feature, users can define their own partition policy and get partitioned onnx models with ease. In this tutorial, we will briefly introduce how to partition a model step by step. In the example, we break the YOLOV3 model into two parts and extract the first part, leaving the post-processing (such as anchor generation and NMS) out of the onnx model.

## Step 1: Mark inputs/outputs

To support the model partition, we need to add `Mark` nodes in the ONNX model. This can be done with mmdeploy's `@mark` decorator. Note that for the `mark` to take effect, the marking operation has to be included in a rewriting function.

First, we mark the model input, which can be done by marking the input tensor `img` in the `forward` method of the `BaseDetector` class, the parent class of all detector classes. Thus we name this marking point `detector_forward` and mark the inputs as `input`. Since there can be three outputs for detectors such as `Mask RCNN`, the outputs are marked as `dets`, `labels`, and `masks`. The following code shows the idea of adding mark functions and calling them in the rewrite. For the source code, you can refer to [mmdeploy/codebase/mmdet/models/detectors/base.py](https://github.com/open-mmlab/mmdeploy/blob/86a50e343a3a45d7bc2ba3256100accc4973e71d/mmdeploy/codebase/mmdet/models/detectors/base.py)

```python
from mmdeploy.core import FUNCTION_REWRITER, mark

@mark(
'detector_forward', inputs=['input'], outputs=['dets', 'labels', 'masks'])
def __forward_impl(ctx, self, img, img_metas=None, **kwargs):
...


@FUNCTION_REWRITER.register_rewriter(
'mmdet.models.detectors.base.BaseDetector.forward')
def base_detector__forward(ctx, self, img, img_metas=None, **kwargs):
...
# call the mark function
return __forward_impl(...)
```

Then, we have to mark the output feature of `YOLOV3Head`, which is the input argument `pred_maps` of the `get_bboxes` method of the `YOLOV3Head` class. We can add an internal function that only marks `pred_maps` inside the [`yolov3_head__get_bboxes`](https://github.com/open-mmlab/mmdeploy/blob/86a50e343a3a45d7bc2ba3256100accc4973e71d/mmdeploy/codebase/mmdet/models/dense_heads/yolo_head.py#L14) function as follows.

```python
from mmdeploy.core import FUNCTION_REWRITER, mark

@FUNCTION_REWRITER.register_rewriter(
func_name='mmdet.models.dense_heads.YOLOV3Head.get_bboxes')
def yolov3_head__get_bboxes(ctx,
self,
pred_maps,
img_metas,
cfg=None,
rescale=False,
with_nms=True):
# mark pred_maps
@mark('yolo_head', inputs=['pred_maps'])
def __mark_pred_maps(pred_maps):
return pred_maps
pred_maps = __mark_pred_maps(pred_maps)
...
```

Note that `pred_maps` is a list of `Tensor` with three elements. Thus, three `Mark` nodes with op names `pred_maps.0`, `pred_maps.1`, and `pred_maps.2` will be added to the onnx model.

## Step 2: Add partition config

After marking the nodes that will be used to split the model, we can add a deployment config file `configs/mmdet/detection/yolov3_partition_onnxruntime_static.py`. If you are not familiar with how to write a config, you can check [write_config.md](../02-how-to-run/write_config.md).

In the config file, we need to add `partition_config`. The key part is `partition_cfg`, a list of dicts that designate the start nodes and end nodes of each model segment. Since we only want to keep `YOLOV3` without the post-processing, we can set `start` to `['detector_forward:input']` and `end` to `['yolo_head:input']`. Note that `start` and `end` can contain multiple marks.

```python
_base_ = ['./detection_onnxruntime_static.py']

onnx_config = dict(input_shape=[608, 608])
partition_config = dict(
type='yolov3_partition', # the partition policy name
apply_marks=True, # should always be set to True
partition_cfg=[
dict(
save_file='yolov3.onnx', # filename to save the partitioned onnx model
start=['detector_forward:input'], # [mark_name:input/output, ...]
end=['yolo_head:input']) # [mark_name:input/output, ...]
])

```

## Step 3: Get partitioned onnx models

Once the nodes are marked and `partition_config` is set properly in the deployment config, we can use the `torch2onnx` [tool](../useful_tools.md) to export the model to onnx and get the partitioned onnx files.

```shell
python tools/torch2onnx.py \
configs/mmdet/detection/yolov3_partition_onnxruntime_static.py \
../mmdetection/configs/yolo/yolov3_d53_mstrain-608_273e_coco.py \
https://download.openmmlab.com/mmdetection/v2.0/yolo/yolov3_d53_mstrain-608_273e_coco/yolov3_d53_mstrain-608_273e_coco_20210518_115020-a2c3acb8.pth \
../mmdetection/demo/demo.jpg \
--work-dir ./work-dirs/mmdet/yolov3/ort/partition
```

After running the script above, we will have the partitioned onnx file `yolov3.onnx` in the `work-dir`. You can use the visualization tool [netron](https://netron.app/) to check the model structure.
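
Besides netron, you can also do a quick programmatic sanity check with ONNX Runtime. The snippet below is only a minimal sketch, not part of this commit: it assumes the partitioned file sits under the `--work-dir` used above and that the model takes a single `1x3x608x608` float image tensor; query the session for the actual input/output names and shapes rather than hard-coding them.

```python
# Minimal sketch: inspect and run the partitioned model with ONNX Runtime.
# The file path and the dummy input shape below are assumptions, not part of this commit.
import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession(
    './work-dirs/mmdet/yolov3/ort/partition/yolov3.onnx',
    providers=['CPUExecutionProvider'])

# For this partition we expect one image input and, as outputs, the three
# feature maps that would normally be fed to the YOLO head.
print([(i.name, i.shape) for i in sess.get_inputs()])
print([(o.name, o.shape) for o in sess.get_outputs()])

dummy = np.random.rand(1, 3, 608, 608).astype(np.float32)  # assumed input shape
outputs = sess.run(None, {sess.get_inputs()[0].name: dummy})
print([o.shape for o in outputs])
```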

With the partitioned onnx file, you can refer to [useful_tools.md](../useful_tools.md) for follow-up procedures such as `onnx2ncnn` and `onnx2tensorrt`.
10 changes: 5 additions & 5 deletions docs/en/06-developer-guide/support_new_backend.md
@@ -1,16 +1,16 @@
## How to support new backends
# How to support new backends

MMDeploy supports a number of backend engines. We welcome the contribution of new backends. In this tutorial, we will introduce the general procedures to support a new backend in MMDeploy.

### Prerequisites
## Prerequisites

Before contributing the codes, there are some requirements for the new backend that need to be checked:

- The backend must support ONNX as IR.
- If the backend requires model files or weight files other than a ".onnx" file, a conversion tool that converts the ".onnx" file to model files and weight files is required. The tool can be a Python API, a script, or an executable program.
- It is highly recommended that the backend provides a Python interface to load the backend files and inference for validation.

### Support backend conversion
## Support backend conversion

The backends in MMDeploy must support ONNX. The backend either loads the ".onnx" file directly, or converts the ".onnx" file to its own format using the conversion tool. In this section, we will introduce the steps to support backend conversion.

@@ -155,7 +155,7 @@ The backends in MMDeploy must support the ONNX. The backend loads the ".onnx" file

7. Add docstring and unit tests for new code :).

### Support backend inference
## Support backend inference

Although the backend engines are usually implemented in C/C++, it is convenient for testing and debugging if the backend provides a Python inference interface. We encourage the contributors to support backend inference in the Python interface of MMDeploy. In this section we will introduce the steps to support backend inference.

@@ -230,7 +230,7 @@ Although the backend engines are usually implemented in C/C++, it is convenient

5. Add docstring and unit tests for new code :).

### Support new backends using MMDeploy as a third party
## Support new backends using MMDeploy as a third party

The previous parts show how to add a new backend in MMDeploy, which requires changing its source code. However, if we treat MMDeploy as a third-party package, the methods above are no longer efficient. To this end, adding a new backend requires us to pre-install another package named `aenum`. We can install it directly through `pip install aenum`.
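
As a rough illustration of what this enables (this sketch is not the example from the commit, and the backend name below is a placeholder), the new backend can be registered into MMDeploy's `Backend` enum from outside the source tree with `aenum`:

```python
# Hedged sketch: extend MMDeploy's Backend enum from a third-party package.
# 'MY_BACKEND' / 'my_backend' are placeholder names for illustration only;
# `Backend` is assumed to be importable from mmdeploy.utils.
from aenum import extend_enum

from mmdeploy.utils import Backend

if not hasattr(Backend, 'MY_BACKEND'):
    extend_enum(Backend, 'MY_BACKEND', 'my_backend')

print(Backend.MY_BACKEND.value)  # -> 'my_backend'
```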

8 changes: 4 additions & 4 deletions docs/en/06-developer-guide/support_new_model.md
@@ -1,8 +1,8 @@
## How to support new models
# How to support new models

We provide several tools to support model conversion.

### Function Rewriter
## Function Rewriter

PyTorch neural networks are written in Python, which eases algorithm development. But the use of Python control flow and third-party libraries makes it difficult to export the network to an intermediate representation. We provide a 'monkey patch' tool to rewrite an unsupported function into another one that can be exported. Here is an example:

@@ -26,7 +26,7 @@ It is easy to use the function rewriter. Just add a decorator with arguments:

The arguments are the same as those of the original function, except for a context `ctx` as the first argument. The context provides some useful information, such as the deployment config `ctx.cfg` and the original function (which has been overridden) `ctx.origin_func`.
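
For instance, a rewriter that keeps the original behaviour and only adjusts the result for export could look like the following sketch; the target `func_name` here is hypothetical and only meant to show how `ctx.cfg` and `ctx.origin_func` are used.

```python
# Hedged sketch of a function rewriter; 'my_project.MyModel.forward' is a
# placeholder target, not a rewrite that ships with MMDeploy.
from mmdeploy.core import FUNCTION_REWRITER


@FUNCTION_REWRITER.register_rewriter(
    func_name='my_project.MyModel.forward')
def my_model__forward(ctx, self, x):
    deploy_cfg = ctx.cfg             # deployment config, e.g. to read backend options
    ret = ctx.origin_func(self, x)   # call the original (overridden) forward
    # ... post-process `ret` here so that it can be exported to ONNX ...
    return ret
```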

### Module Rewriter
## Module Rewriter

If you want to replace a whole module with another one, we have another rewriter as follows:

@@ -66,7 +66,7 @@ Just like function rewriter, add a decorator with arguments:

All instances of the module in the network will be replaced with instances of this new class. The original module and the deployment config will be passed as the first two arguments.
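
A bare-bones sketch of such a rewriting module is shown below; the registered module path and the backend restriction are placeholders, and the `MODULE_REWRITER.register_rewrite_module` decorator name is assumed from upstream MMDeploy.

```python
# Hedged sketch of a module rewriter; 'my_project.models.MyBlock' and the
# 'tensorrt' backend argument are placeholders for illustration only.
import torch

from mmdeploy.core import MODULE_REWRITER


@MODULE_REWRITER.register_rewrite_module(
    'my_project.models.MyBlock', backend='tensorrt')
class MyBlockStatic(torch.nn.Module):

    def __init__(self, module, deploy_cfg, **kwargs):
        super().__init__()
        self._module = module          # the original module instance
        self._deploy_cfg = deploy_cfg  # the deployment config

    def forward(self, x):
        # Re-implement forward in an export-friendly way; here we simply
        # delegate to the original module as a placeholder.
        return self._module(x)
```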

### Custom Symbolic
## Custom Symbolic

The mappings between PyTorch and ONNX are defined in PyTorch with symbolic functions. The custom symbolic function can help us bypass some ONNX nodes that are unsupported by the inference engine.

1 change: 1 addition & 0 deletions docs/en/index.rst
@@ -72,6 +72,7 @@ You can switch between Chinese and English documents in the lower-left corner of
06-developer-guide/support_new_backend.md
06-developer-guide/add_test_units_for_backend_ops.md
06-developer-guide/test_rewritten_models.md
06-developer-guide/partition_model.md

.. toctree::
:maxdepth: 1