add code for export onnx model
Dean authored and Dean committed Feb 17, 2021
1 parent 84fde67 commit 719cd51
Showing 16 changed files with 675 additions and 48 deletions.
59 changes: 57 additions & 2 deletions README.md
@@ -1,4 +1,7 @@
# Center-based 3D Object Detection and Tracking
# CenterNet-PointPillars PyTorch model conversion to ONNX
Welcome to CenterNet! This project is forked from [tianweiy/CenterPoint](https://github.com/tianweiy/CenterPoint). I implemented some code to export the CenterNet-PointPillars model to ONNX.

Center-based 3D Object Detection and Tracking

3D Object Detection and Tracking using center points in the bird's-eye view.

@@ -19,7 +22,6 @@


## NEWS

[2021-01-06] CenterPoint v1.0 is released. Without bells and whistles, we rank first among all Lidar-only methods on Waymo Open Dataset with a single model that runs at 11 FPS. Check out CenterPoint's model zoo for [Waymo](configs/waymo/README.md) and [nuScenes](configs/nusc/README.md).

[2020-12-11] 3 out of the top 4 entries in the recent NeurIPS 2020 [nuScenes 3D Detection challenge](https://www.nuscenes.org/object-detection?externalData=all&mapData=all&modalities=Any) used CenterPoint. Congratulations to the other participants and please stay tuned for more updates on nuScenes and Waymo soon.
@@ -103,6 +105,59 @@ Then run a demo by ```python tools/demo.py```. If set up correctly, you will see a

Please refer to [GETTING_START](docs/GETTING_START.md) to prepare the data. Then follow the instruction there to reproduce our detection and tracking results. All detection configurations are included in [configs](configs) and we provide the scripts for all tracking experiments in [tracking_scripts](tracking_scripts).

## Export ONNX
I divide the PointPillars model into two parts: pfe (which includes PillarFeatureNet) and rpn (which includes RPN and CenterHead). The PointPillarsScatter module itself is not exported; a ScatterND node is used in its place, as sketched below.
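
For orientation, PointPillarsScatter only writes each pillar's feature vector into a dense BEV canvas at that pillar's grid coordinate, which is exactly the kind of indexed update an ONNX ScatterND node performs. Below is a minimal PyTorch sketch of that scatter step; the function name, tensor layouts, and the 512x512 canvas size (102.4 m range at 0.2 m voxels) are illustrative assumptions, not the repo's actual code.

```python
import torch

def scatter_to_bev(pillar_features, coords, nx=512, ny=512):
    """Scatter per-pillar features into a dense BEV canvas.

    pillar_features: (num_pillars, C) tensor produced by the pfe part.
    coords:          (num_pillars, 2) integer (y, x) grid indices of each pillar.
    Returns a (1, C, ny, nx) canvas, i.e. the input the rpn part expects.
    """
    num_features = pillar_features.shape[1]
    canvas = pillar_features.new_zeros(num_features, ny * nx)
    flat_idx = coords[:, 0] * nx + coords[:, 1]   # linearize (y, x) indices
    canvas[:, flat_idx] = pillar_features.t()     # the ScatterND-style indexed write
    return canvas.view(1, num_features, ny, nx)
```

When pfe.onnx and rpn.onnx are merged, a ScatterND node carries out the same indexed write, which is why it can stand in for PointPillarsScatter.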

- Install packages:
```shell
pip install onnx onnx-simplifier onnxruntime
```
- Step 1. Download the [trained model (latest.pth)](https://drive.google.com/drive/folders/1K_wHrBo6yRSG7H7UUjKI4rPnyEA8HvOp) and the nuScenes mini dataset (v1.0-mini.tar).
- Step 2. Prepare the dataset. Please refer to [docs/NUSC.md](docs/NUSC.md).

- Step 3. Export pfe.onnx and rpn.onnx:
```shell
python tool/export_pointpillars_onnx.py
```
- Step 4. Use onnx-simplifier and the provided script to simplify pfe.onnx and rpn.onnx (a sketch of what the script roughly does follows this list):
```shell
python tool/simplify_model.py
```
- Step 5. Merge pfe.onnx and rpn.onnx. A ScatterND node connects pfe and rpn. TensorRT does not support the ScatterND operator, so to run CenterNet-PointPillars with TensorRT, run pfe.onnx and rpn.onnx separately:
```shell
python tool/merge_pfe_rpn_model.py
```
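
For reference, `tool/simplify_model.py` from step 4 presumably drives onnx-simplifier through its Python API; the sketch below shows that usage pattern with assumed file paths.

```python
import onnx
from onnxsim import simplify

# Paths are assumptions; adjust to wherever the exported models were written.
for name in ("pfe.onnx", "rpn.onnx"):
    model = onnx.load(f"onnx_model/{name}")
    simplified, ok = simplify(model)   # constant folding + redundant-node removal
    assert ok, f"onnx-simplifier could not validate {name}"
    onnx.save(simplified, f"onnx_model/{name[:-5]}_sim.onnx")
```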
All ONNX models are saved in [onnx_model](onnx_model). A quick ONNX Runtime sanity check of the exported graphs is sketched below.
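
A simple way to confirm that both exported graphs execute, mirroring the TensorRT note in step 5, is to run them separately in ONNX Runtime on random inputs. The feed below infers shapes from the graphs themselves (dynamic dimensions become 1) and assumes all inputs are float32; in a real pipeline the pfe outputs would be scattered into the BEV canvas (see the sketch above) before being fed to rpn.

```python
import numpy as np
import onnxruntime as ort

def random_feed(session):
    """Build a dummy float32 feed for every graph input; dynamic dims become 1."""
    feed = {}
    for inp in session.get_inputs():
        shape = [d if isinstance(d, int) else 1 for d in inp.shape]
        feed[inp.name] = np.random.rand(*shape).astype(np.float32)
    return feed

pfe = ort.InferenceSession("onnx_model/pfe.onnx")
rpn = ort.InferenceSession("onnx_model/rpn.onnx")
print([out.shape for out in pfe.run(None, random_feed(pfe))])
print([out.shape for out in rpn.run(None, random_feed(rpn))])
```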

I added an argument (export_onnx) for exporting the ONNX model in the [config file](configs/nusc/pp/nusc_centerpoint_pp_02voxel_two_pfn_10sweep_demo_export_onnx.py):

```python
model = dict(
type="PointPillars",
pretrained=None,
export_onnx=True, # for export onnx model
reader=dict(
type="PillarFeatureNet",
num_filters=[64, 64],
num_input_features=5,
with_distance=False,
voxel_size=(0.2, 0.2, 8),
pc_range=(-51.2, -51.2, -5.0, 51.2, 51.2, 3.0),
export_onnx=True, # for export onnx model
),
backbone=dict(type="PointPillarsScatter", ds_factor=1),
neck=dict(
type="RPN",
layer_nums=[3, 5, 5],
ds_layer_strides=[2, 2, 2],
ds_num_filters=[64, 128, 256],
us_layer_strides=[0.5, 1, 2],
us_num_filters=[128, 128, 128],
num_input_features=64,
logger=logging.getLogger("RPN"),
),
```
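
`tool/export_pointpillars_onnx.py` is the authoritative reference for what exactly gets traced. Purely as orientation, the sketch below shows the general pattern such a script follows: build the detector from the export config above with the upstream det3d helpers, load the checkpoint, and hand the pfe part to `torch.onnx.export`. The checkpoint path, the dummy input layout, the opset, and the assumption that the reader accepts a single tensor in export mode are all guesses, not the repo's code.

```python
import torch
from det3d.torchie import Config
from det3d.models import build_detector
from det3d.torchie.trainer import load_checkpoint

cfg = Config.fromfile(
    "configs/nusc/pp/nusc_centerpoint_pp_02voxel_two_pfn_10sweep_demo_export_onnx.py")
model = build_detector(cfg.model, train_cfg=None, test_cfg=cfg.test_cfg)
load_checkpoint(model, "latest.pth", map_location="cpu")
model.eval()

# With export_onnx=True the reader returns raw, trace-friendly tensors,
# so the pfe part can be traced on a dummy pillar tensor (layout assumed).
dummy_pillars = torch.rand(1, 10, 30000, 20)
torch.onnx.export(model.reader, (dummy_pillars,), "onnx_model/pfe.onnx",
                  opset_version=11)
```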


## License

@@ -24,13 +24,15 @@
model = dict(
type="PointPillars",
pretrained=None,
# export_onnx=True,
reader=dict(
type="PillarFeatureNet",
num_filters=[64, 64],
num_input_features=5,
with_distance=False,
voxel_size=(0.2, 0.2, 8),
pc_range=(-51.2, -51.2, -5.0, 51.2, 51.2, 3.0),
# export_onnx=True,
),
backbone=dict(type="PointPillarsScatter", ds_factor=1),
neck=dict(
@@ -83,12 +85,12 @@
# dataset settings
dataset_type = "NuScenesDataset"
nsweeps = 10
data_root = "data/nuScenes"
data_root = "/home/dean/dataset/nuscenes"

db_sampler = dict(
type="GT-AUG",
enable=False,
db_info_path="data/nuScenes/dbinfos_train_10sweeps_withvelo.pkl",
db_info_path="/home/dean/dataset/nuscenes/dbinfos_train_10sweeps_withvelo.pkl",
sample_groups=[
dict(car=2),
dict(truck=3),
@@ -159,8 +161,8 @@
dict(type="Reformat"),
]

train_anno = "demo/nuScenes/demo_infos.pkl"
val_anno = "demo/nuScenes/demo_infos.pkl"
train_anno = "/home/dean/dataset/nuscenes/dbinfos_train_10sweeps_withvelo.pkl"
val_anno = "/home/dean/dataset/nuscenes/infos_val_10sweeps_withvelo_filter_True.pkl"
test_anno = None

data = dict(
@@ -0,0 +1,229 @@
import itertools
import logging

from det3d.utils.config_tool import get_downsample_factor

tasks = [
dict(num_class=1, class_names=["car"]),
dict(num_class=2, class_names=["truck", "construction_vehicle"]),
dict(num_class=2, class_names=["bus", "trailer"]),
dict(num_class=1, class_names=["barrier"]),
dict(num_class=2, class_names=["motorcycle", "bicycle"]),
dict(num_class=2, class_names=["pedestrian", "traffic_cone"]),
]

class_names = list(itertools.chain(*[t["class_names"] for t in tasks]))

# training and testing settings
target_assigner = dict(
tasks=tasks,
)


# model settings
model = dict(
type="PointPillars",
pretrained=None,
export_onnx=True,
reader=dict(
type="PillarFeatureNet",
num_filters=[64, 64],
num_input_features=5,
with_distance=False,
voxel_size=(0.2, 0.2, 8),
pc_range=(-51.2, -51.2, -5.0, 51.2, 51.2, 3.0),
export_onnx=True,
),
backbone=dict(type="PointPillarsScatter", ds_factor=1),
neck=dict(
type="RPN",
layer_nums=[3, 5, 5],
ds_layer_strides=[2, 2, 2],
ds_num_filters=[64, 128, 256],
us_layer_strides=[0.5, 1, 2],
us_num_filters=[128, 128, 128],
num_input_features=64,
logger=logging.getLogger("RPN"),
),
bbox_head=dict(
# type='RPNHead',
type="CenterHead",
in_channels=sum([128, 128, 128]),
tasks=tasks,
dataset='nuscenes',
weight=0.25,
code_weights=[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.2, 0.2, 1.0, 1.0],
common_heads={'reg': (2, 2), 'height': (1, 2), 'dim':(3, 2), 'rot':(2, 2), 'vel': (2, 2)}, # (output_channel, num_conv)
),
)

assigner = dict(
target_assigner=target_assigner,
out_size_factor=get_downsample_factor(model),
gaussian_overlap=0.1,
max_objs=500,
min_radius=2,
)


train_cfg = dict(assigner=assigner)

test_cfg = dict(
post_center_limit_range=[-61.2, -61.2, -10.0, 61.2, 61.2, 10.0],
max_per_img=500,
nms=dict(
nms_pre_max_size=1000,
nms_post_max_size=83,
nms_iou_threshold=0.2,
),
score_threshold=0.1,
pc_range=[-51.2, -51.2],
out_size_factor=get_downsample_factor(model),
voxel_size=[0.2, 0.2]
)

# dataset settings
dataset_type = "NuScenesDataset"
nsweeps = 10
data_root = "/home/dean/dataset/nuscenes"

db_sampler = dict(
type="GT-AUG",
enable=False,
db_info_path="/home/dean/dataset/nuscenes/dbinfos_train_10sweeps_withvelo.pkl",
sample_groups=[
dict(car=2),
dict(truck=3),
dict(construction_vehicle=7),
dict(bus=4),
dict(trailer=6),
dict(barrier=2),
dict(motorcycle=6),
dict(bicycle=6),
dict(pedestrian=2),
dict(traffic_cone=2),
],
db_prep_steps=[
dict(
filter_by_min_num_points=dict(
car=5,
truck=5,
bus=5,
trailer=5,
construction_vehicle=5,
traffic_cone=5,
barrier=5,
motorcycle=5,
bicycle=5,
pedestrian=5,
)
),
dict(filter_by_difficulty=[-1],),
],
global_random_rotation_range_per_object=[0, 0],
rate=1.0,
)
train_preprocessor = dict(
mode="train",
shuffle_points=True,
global_rot_noise=[-0.3925, 0.3925],
global_scale_noise=[0.95, 1.05],
db_sampler=db_sampler,
class_names=class_names,
)

val_preprocessor = dict(
mode="val",
shuffle_points=False,
)

voxel_generator = dict(
range=[-51.2, -51.2, -5.0, 51.2, 51.2, 3.0],
voxel_size=[0.2, 0.2, 8],
max_points_in_voxel=20,
max_voxel_num=[30000, 60000],
)

train_pipeline = [
dict(type="LoadPointCloudFromFile", dataset=dataset_type),
dict(type="LoadPointCloudAnnotations", with_bbox=True),
dict(type="Preprocess", cfg=train_preprocessor),
dict(type="Voxelization", cfg=voxel_generator),
dict(type="AssignLabel", cfg=train_cfg["assigner"]),
dict(type="Reformat"),
]
test_pipeline = [
dict(type="LoadPointCloudFromFile", dataset=dataset_type),
dict(type="LoadPointCloudAnnotations", with_bbox=True),
dict(type="Preprocess", cfg=val_preprocessor),
dict(type="Voxelization", cfg=voxel_generator),
dict(type="AssignLabel", cfg=train_cfg["assigner"]),
dict(type="Reformat"),
]

train_anno = "/home/dean/dataset/nuscenes/dbinfos_train_10sweeps_withvelo.pkl"
val_anno = "/home/dean/dataset/nuscenes/infos_val_10sweeps_withvelo_filter_True.pkl"
test_anno = None

data = dict(
samples_per_gpu=4,
workers_per_gpu=8,
train=dict(
type=dataset_type,
root_path=data_root,
info_path=train_anno,
ann_file=train_anno,
nsweeps=nsweeps,
class_names=class_names,
pipeline=train_pipeline,
),
val=dict(
type=dataset_type,
root_path=data_root,
info_path=val_anno,
test_mode=True,
ann_file=val_anno,
nsweeps=nsweeps,
class_names=class_names,
pipeline=test_pipeline,
),
test=dict(
type=dataset_type,
root_path=data_root,
info_path=test_anno,
ann_file=test_anno,
nsweeps=nsweeps,
class_names=class_names,
pipeline=test_pipeline,
),
)


optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))
# optimizer
optimizer = dict(
type="adam", amsgrad=0.0, wd=0.01, fixed_wd=True, moving_average=False,
)
lr_config = dict(
type="one_cycle", lr_max=0.001, moms=[0.95, 0.85], div_factor=10.0, pct_start=0.4,
)

checkpoint_config = dict(interval=1)
# yapf:disable
log_config = dict(
interval=5,
hooks=[
dict(type="TextLoggerHook"),
# dict(type='TensorboardLoggerHook')
],
)
# yapf:enable
# runtime settings
total_epochs = 20
device_ids = range(8)
dist_params = dict(backend="nccl", init_method="env://")
log_level = "INFO"
work_dir = './work_dirs/{}/'.format(__file__[__file__.rfind('/') + 1:-3])
load_from = None
resume_from = None
workflow = [('train', 1)]
5 changes: 3 additions & 2 deletions det3d/models/bbox_heads/center_head.py
@@ -239,7 +239,7 @@ def forward(self, x, *kwargs):

for task in self.tasks:
ret_dicts.append(task(x))

self.ret_dicts = ret_dicts
return ret_dicts

def _sigmoid(self, x):
@@ -454,7 +454,8 @@ def post_processing(self, batch_box_preds, batch_hm, test_cfg, post_center_range
for i in range(batch_size):
box_preds = batch_box_preds[i]
hm_preds = batch_hm[i]



scores, labels = torch.max(hm_preds, dim=-1)

score_mask = scores > test_cfg.score_threshold
6 changes: 5 additions & 1 deletion det3d/models/detectors/point_pillars.py
@@ -13,10 +13,12 @@ def __init__(
train_cfg=None,
test_cfg=None,
pretrained=None,
export_onnx = False
):
super(PointPillars, self).__init__(
reader, backbone, neck, bbox_head, train_cfg, test_cfg, pretrained
)
self.export_onnx = export_onnx

def extract_feat(self, data):
input_features = self.reader(
@@ -47,7 +49,9 @@ def forward(self, example, return_loss=True, **kwargs):

x = self.extract_feat(data)
preds = self.bbox_head(x)


if self.export_onnx:
return preds
if return_loss:
return self.bbox_head.loss(example, preds)
else:
