diff --git a/.dev_scripts/gather_models.py b/.dev_scripts/gather_models.py index 58919fd444..38270a73b8 100644 --- a/.dev_scripts/gather_models.py +++ b/.dev_scripts/gather_models.py @@ -25,6 +25,7 @@ '_6x_': 73, '_50e_': 50, '_80e_': 80, + '_100e_': 100, '_150e_': 150, '_200e_': 200, '_250e_': 250, diff --git a/README.md b/README.md index 4d0c0a98e2..8a2075919a 100644 --- a/README.md +++ b/README.md @@ -75,6 +75,7 @@ Support backbones: - [x] PointNet (CVPR'2017) - [x] PointNet++ (NeurIPS'2017) - [x] RegNet (CVPR'2020) +- [x] DGCNN (TOG'2019) Support methods @@ -94,25 +95,27 @@ Support methods - [x] [Group-Free-3D (Arxiv'2021)](configs/groupfree3d/README.md) - [x] [ImVoxelNet (Arxiv'2021)](configs/imvoxelnet/README.md) - [x] [PAConv (CVPR'2021)](configs/paconv/README.md) - -| | ResNet | ResNeXt | SENet |PointNet++ | HRNet | RegNetX | Res2Net | -|--------------------|:--------:|:--------:|:--------:|:---------:|:-----:|:--------:|:-----:| -| SECOND | ☐ | ☐ | ☐ | ✗ | ☐ | ✓ | ☐ | -| PointPillars | ☐ | ☐ | ☐ | ✗ | ☐ | ✓ | ☐ | -| FreeAnchor | ☐ | ☐ | ☐ | ✗ | ☐ | ✓ | ☐ | -| VoteNet | ✗ | ✗ | ✗ | ✓ | ✗ | ✗ | ✗ | -| H3DNet | ✗ | ✗ | ✗ | ✓ | ✗ | ✗ | ✗ | -| 3DSSD | ✗ | ✗ | ✗ | ✓ | ✗ | ✗ | ✗ | -| Part-A2 | ☐ | ☐ | ☐ | ✗ | ☐ | ✓ | ☐ | -| MVXNet | ☐ | ☐ | ☐ | ✗ | ☐ | ✓ | ☐ | -| CenterPoint | ☐ | ☐ | ☐ | ✗ | ☐ | ✓ | ☐ | -| SSN | ☐ | ☐ | ☐ | ✗ | ☐ | ✓ | ☐ | -| ImVoteNet | ✗ | ✗ | ✗ | ✓ | ✗ | ✗ | ✗ | -| FCOS3D | ✓ | ☐ | ☐ | ✗ | ☐ | ☐ | ☐ | -| PointNet++ | ✗ | ✗ | ✗ | ✓ | ✗ | ✗ | ✗ | -| Group-Free-3D | ✗ | ✗ | ✗ | ✓ | ✗ | ✗ | ✗ | -| ImVoxelNet | ✓ | ✗ | ✗ | ✗ | ✗ | ✗ | ✗ | -| PAConv | ✗ | ✗ | ✗ | ✓ | ✗ | ✗ | ✗ | +- [x] [DGCNN (TOG'2019)](configs/dgcnn/README.md) + +| | ResNet | ResNeXt | SENet |PointNet++ |DGCNN | HRNet | RegNetX | Res2Net | +|--------------------|:--------:|:--------:|:--------:|:---------:|:---------:|:-----:|:--------:|:-----:| +| SECOND | ☐ | ☐ | ☐ | ✗ | ✗ | ☐ | ✓ | ☐ | +| PointPillars | ☐ | ☐ | ☐ | ✗ | ✗ | ☐ | ✓ | ☐ | +| FreeAnchor | ☐ | ☐ | ☐ | ✗ | ✗ | ☐ | ✓ 
| ☐ | +| VoteNet | ✗ | ✗ | ✗ | ✓ | ✗ | ✗ | ✗ | ✗ | +| H3DNet | ✗ | ✗ | ✗ | ✓ | ✗ | ✗ | ✗ | ✗ | +| 3DSSD | ✗ | ✗ | ✗ | ✓ | ✗ | ✗ | ✗ | ✗ | +| Part-A2 | ☐ | ☐ | ☐ | ✗ | ✗ | ☐ | ✓ | ☐ | +| MVXNet | ☐ | ☐ | ☐ | ✗ | ✗ | ☐ | ✓ | ☐ | +| CenterPoint | ☐ | ☐ | ☐ | ✗ | ✗ | ☐ | ✓ | ☐ | +| SSN | ☐ | ☐ | ☐ | ✗ | ✗ | ☐ | ✓ | ☐ | +| ImVoteNet | ✗ | ✗ | ✗ | ✓ | ✗ | ✗ | ✗ | ✗ | +| FCOS3D | ✓ | ☐ | ☐ | ✗ | ✗ | ☐ | ☐ | ☐ | +| PointNet++ | ✗ | ✗ | ✗ | ✓ | ✗ | ✗ | ✗ | ✗ | +| Group-Free-3D | ✗ | ✗ | ✗ | ✓ | ✗ | ✗ | ✗ | ✗ | +| ImVoxelNet | ✓ | ✗ | ✗ | ✗ | ✗ | ✗ | ✗ | ✗ | +| PAConv | ✗ | ✗ | ✗ | ✓ | ✗ | ✗ | ✗ | ✗ | +| DGCNN | ✗ | ✗ | ✗ | ✗ | ✓ | ✗ | ✗ | ✗ | Other features - [x] [Dynamic Voxelization](configs/dynamic_voxelization/README.md) diff --git a/README_zh-CN.md b/README_zh-CN.md index 30c0d4c561..4fbf874a93 100644 --- a/README_zh-CN.md +++ b/README_zh-CN.md @@ -74,6 +74,7 @@ MMDetection3D 是一个基于 PyTorch 的目标检测开源工具箱, 下一代 - [x] PointNet (CVPR'2017) - [x] PointNet++ (NeurIPS'2017) - [x] RegNet (CVPR'2020) +- [x] DGCNN (TOG'2019) 已支持的算法: @@ -93,25 +94,27 @@ MMDetection3D 是一个基于 PyTorch 的目标检测开源工具箱, 下一代 - [x] [Group-Free-3D (Arxiv'2021)](configs/groupfree3d/README.md) - [x] [ImVoxelNet (Arxiv'2021)](configs/imvoxelnet/README.md) - [x] [PAConv (CVPR'2021)](configs/paconv/README.md) - -| | ResNet | ResNeXt | SENet |PointNet++ | HRNet | RegNetX | Res2Net | -|--------------------|:--------:|:--------:|:--------:|:---------:|:-----:|:--------:|:-----:| -| SECOND | ☐ | ☐ | ☐ | ✗ | ☐ | ✓ | ☐ | -| PointPillars | ☐ | ☐ | ☐ | ✗ | ☐ | ✓ | ☐ | -| FreeAnchor | ☐ | ☐ | ☐ | ✗ | ☐ | ✓ | ☐ | -| VoteNet | ✗ | ✗ | ✗ | ✓ | ✗ | ✗ | ✗ | -| H3DNet | ✗ | ✗ | ✗ | ✓ | ✗ | ✗ | ✗ | -| 3DSSD | ✗ | ✗ | ✗ | ✓ | ✗ | ✗ | ✗ | -| Part-A2 | ☐ | ☐ | ☐ | ✗ | ☐ | ✓ | ☐ | -| MVXNet | ☐ | ☐ | ☐ | ✗ | ☐ | ✓ | ☐ | -| CenterPoint | ☐ | ☐ | ☐ | ✗ | ☐ | ✓ | ☐ | -| SSN | ☐ | ☐ | ☐ | ✗ | ☐ | ✓ | ☐ | -| ImVoteNet | ✗ | ✗ | ✗ | ✓ | ✗ | ✗ | ✗ | -| FCOS3D | ✓ | ☐ | ☐ | ✗ | ☐ | ☐ | ☐ | -| PointNet++ | ✗ | ✗ | ✗ | ✓ | ✗ | ✗ | ✗ | -| 
Group-Free-3D | ✗ | ✗ | ✗ | ✓ | ✗ | ✗ | ✗ | -| ImVoxelNet | ✓ | ✗ | ✗ | ✗ | ✗ | ✗ | ✗ | -| PAConv | ✗ | ✗ | ✗ | ✓ | ✗ | ✗ | ✗ | +- [x] [DGCNN (TOG'2019)](configs/dgcnn/README.md) + +| | ResNet | ResNeXt | SENet |PointNet++ |DGCNN | HRNet | RegNetX | Res2Net | +|--------------------|:--------:|:--------:|:--------:|:---------:|:---------:|:-----:|:--------:|:-----:| +| SECOND | ☐ | ☐ | ☐ | ✗ | ✗ | ☐ | ✓ | ☐ | +| PointPillars | ☐ | ☐ | ☐ | ✗ | ✗ | ☐ | ✓ | ☐ | +| FreeAnchor | ☐ | ☐ | ☐ | ✗ | ✗ | ☐ | ✓ | ☐ | +| VoteNet | ✗ | ✗ | ✗ | ✓ | ✗ | ✗ | ✗ | ✗ | +| H3DNet | ✗ | ✗ | ✗ | ✓ | ✗ | ✗ | ✗ | ✗ | +| 3DSSD | ✗ | ✗ | ✗ | ✓ | ✗ | ✗ | ✗ | ✗ | +| Part-A2 | ☐ | ☐ | ☐ | ✗ | ✗ | ☐ | ✓ | ☐ | +| MVXNet | ☐ | ☐ | ☐ | ✗ | ✗ | ☐ | ✓ | ☐ | +| CenterPoint | ☐ | ☐ | ☐ | ✗ | ✗ | ☐ | ✓ | ☐ | +| SSN | ☐ | ☐ | ☐ | ✗ | ✗ | ☐ | ✓ | ☐ | +| ImVoteNet | ✗ | ✗ | ✗ | ✓ | ✗ | ✗ | ✗ | ✗ | +| FCOS3D | ✓ | ☐ | ☐ | ✗ | ✗ | ☐ | ☐ | ☐ | +| PointNet++ | ✗ | ✗ | ✗ | ✓ | ✗ | ✗ | ✗ | ✗ | +| Group-Free-3D | ✗ | ✗ | ✗ | ✓ | ✗ | ✗ | ✗ | ✗ | +| ImVoxelNet | ✓ | ✗ | ✗ | ✗ | ✗ | ✗ | ✗ | ✗ | +| PAConv | ✗ | ✗ | ✗ | ✓ | ✗ | ✗ | ✗ | ✗ | +| DGCNN | ✗ | ✗ | ✗ | ✗ | ✓ | ✗ | ✗ | ✗ | 其他特性 - [x] [Dynamic Voxelization](configs/dynamic_voxelization/README.md) diff --git a/configs/_base_/models/dgcnn.py b/configs/_base_/models/dgcnn.py new file mode 100644 index 0000000000..61e7272692 --- /dev/null +++ b/configs/_base_/models/dgcnn.py @@ -0,0 +1,28 @@ +# model settings +model = dict( + type='EncoderDecoder3D', + backbone=dict( + type='DGCNNBackbone', + in_channels=9, # [xyz, rgb, normal_xyz], modified with dataset + num_samples=(20, 20, 20), + knn_modes=('D-KNN', 'F-KNN', 'F-KNN'), + radius=(None, None, None), + gf_channels=((64, 64), (64, 64), (64, )), + fa_channels=(1024, ), + act_cfg=dict(type='LeakyReLU', negative_slope=0.2)), + decode_head=dict( + type='DGCNNHead', + fp_channels=(1216, 512), + channels=256, + dropout_ratio=0.5, + conv_cfg=dict(type='Conv1d'), + norm_cfg=dict(type='BN1d'), + 
act_cfg=dict(type='LeakyReLU', negative_slope=0.2), + loss_decode=dict( + type='CrossEntropyLoss', + use_sigmoid=False, + class_weight=None, # modified with dataset + loss_weight=1.0)), + # model training and testing settings + train_cfg=dict(), + test_cfg=dict(mode='slide')) diff --git a/configs/_base_/schedules/seg_cosine_100e.py b/configs/_base_/schedules/seg_cosine_100e.py new file mode 100644 index 0000000000..3b75932b3a --- /dev/null +++ b/configs/_base_/schedules/seg_cosine_100e.py @@ -0,0 +1,8 @@ +# optimizer +# This schedule is mainly used on S3DIS dataset in segmentation task +optimizer = dict(type='SGD', lr=0.1, momentum=0.9, weight_decay=0.0001) +optimizer_config = dict(grad_clip=None) +lr_config = dict(policy='CosineAnnealing', warmup=None, min_lr=1e-5) + +# runtime settings +runner = dict(type='EpochBasedRunner', max_epochs=100) diff --git a/configs/dgcnn/README.md b/configs/dgcnn/README.md new file mode 100644 index 0000000000..fa31e43f34 --- /dev/null +++ b/configs/dgcnn/README.md @@ -0,0 +1,43 @@ +# Dynamic Graph CNN for Learning on Point Clouds + +## Introduction + + + +We implement DGCNN and provide the results and checkpoints on S3DIS dataset. + +``` +@article{dgcnn, + title={Dynamic Graph CNN for Learning on Point Clouds}, + author={Wang, Yue and Sun, Yongbin and Liu, Ziwei and Sarma, Sanjay E. and Bronstein, Michael M. and Solomon, Justin M.}, + journal={ACM Transactions on Graphics (TOG)}, + year={2019} +} +``` + +**Notice**: We follow the implementations in the original DGCNN paper and a PyTorch implementation of DGCNN [code](https://github.com/AnTao97/dgcnn.pytorch). 
+ +## Results + +### S3DIS + +| Method | Split | Lr schd | Mem (GB) | Inf time (fps) | mIoU (Val set) | Download | +| :-------------------------------------------------------------------------: | :----: | :--------: | :------: | :------------: | :------------: | :----------------------: | +| [DGCNN](./dgcnn_32x4_cosine_100e_s3dis_seg-3d-13class.py) | Area_1 | cosine 100e | 13.1 | | 68.33 | [model](https://download.openmmlab.com/mmdetection3d/v0.17.0_models/dgcnn/dgcnn_32x4_cosine_100e_s3dis_seg-3d-13class/area1/dgcnn_32x4_cosine_100e_s3dis_seg-3d-13class_20210731_000734-39658f14.pth) | [log](https://download.openmmlab.com/mmdetection3d/v0.17.0_models/dgcnn/dgcnn_32x4_cosine_100e_s3dis_seg-3d-13class/area1/dgcnn_32x4_cosine_100e_s3dis_seg-3d-13class_20210731_000734.log.json) | +| [DGCNN](./dgcnn_32x4_cosine_100e_s3dis_seg-3d-13class.py) | Area_2 | cosine 100e | 13.1 | | 40.68 | [model](https://download.openmmlab.com/mmdetection3d/v0.17.0_models/dgcnn/dgcnn_32x4_cosine_100e_s3dis_seg-3d-13class/area2/dgcnn_32x4_cosine_100e_s3dis_seg-3d-13class_20210731_144648-aea9ecb6.pth) | [log](https://download.openmmlab.com/mmdetection3d/v0.17.0_models/dgcnn/dgcnn_32x4_cosine_100e_s3dis_seg-3d-13class/area2/dgcnn_32x4_cosine_100e_s3dis_seg-3d-13class_20210731_144648.log.json) | +| [DGCNN](./dgcnn_32x4_cosine_100e_s3dis_seg-3d-13class.py) | Area_3 | cosine 100e | 13.1 | | 69.38 | [model](https://download.openmmlab.com/mmdetection3d/v0.17.0_models/dgcnn/dgcnn_32x4_cosine_100e_s3dis_seg-3d-13class/area3/dgcnn_32x4_cosine_100e_s3dis_seg-3d-13class_20210801_154629-2ff50ee0.pth) | [log](https://download.openmmlab.com/mmdetection3d/v0.17.0_models/dgcnn/dgcnn_32x4_cosine_100e_s3dis_seg-3d-13class/area3/dgcnn_32x4_cosine_100e_s3dis_seg-3d-13class_20210801_154629.log.json) | +| [DGCNN](./dgcnn_32x4_cosine_100e_s3dis_seg-3d-13class.py) | Area_4 | cosine 100e | 13.1 | | 50.07 | 
[model](https://download.openmmlab.com/mmdetection3d/v0.17.0_models/dgcnn/dgcnn_32x4_cosine_100e_s3dis_seg-3d-13class/area4/dgcnn_32x4_cosine_100e_s3dis_seg-3d-13class_20210802_073551-dffab9cd.pth) | [log](https://download.openmmlab.com/mmdetection3d/v0.17.0_models/dgcnn/dgcnn_32x4_cosine_100e_s3dis_seg-3d-13class/area4/dgcnn_32x4_cosine_100e_s3dis_seg-3d-13class_20210802_073551.log.json) | +| [DGCNN](./dgcnn_32x4_cosine_100e_s3dis_seg-3d-13class.py) | Area_5 | cosine 100e | 13.1 | | 50.59 | [model](https://download.openmmlab.com/mmdetection3d/v0.17.0_models/dgcnn/dgcnn_32x4_cosine_100e_s3dis_seg-3d-13class/area5/dgcnn_32x4_cosine_100e_s3dis_seg-3d-13class_20210730_235824-f277e0c5.pth) | [log](https://download.openmmlab.com/mmdetection3d/v0.17.0_models/dgcnn/dgcnn_32x4_cosine_100e_s3dis_seg-3d-13class/area5/dgcnn_32x4_cosine_100e_s3dis_seg-3d-13class_20210730_235824.log.json) | +| [DGCNN](./dgcnn_32x4_cosine_100e_s3dis_seg-3d-13class.py) | Area_6 | cosine 100e | 13.1 | | 77.94 | [model](https://download.openmmlab.com/mmdetection3d/v0.17.0_models/dgcnn/dgcnn_32x4_cosine_100e_s3dis_seg-3d-13class/area6/dgcnn_32x4_cosine_100e_s3dis_seg-3d-13class_20210802_154317-e3511b32.pth) | [log](https://download.openmmlab.com/mmdetection3d/v0.17.0_models/dgcnn/dgcnn_32x4_cosine_100e_s3dis_seg-3d-13class/area6/dgcnn_32x4_cosine_100e_s3dis_seg-3d-13class_20210802_154317.log.json) | +| [DGCNN](./dgcnn_32x4_cosine_100e_s3dis_seg-3d-13class.py) | 6-fold | | | | 59.43 | | + +**Notes:** + +- We use XYZ+Color+Normalized_XYZ as input in all the experiments on S3DIS datasets. +- `Area_5` Split means training the model on Area_1, 2, 3, 4, 6 and testing on Area_5. +- `6-fold` Split means the overall result of 6 different splits (Area_1, Area_2, Area_3, Area_4, Area_5 and Area_6 Splits). +- Users need to modify `train_area` and `test_area` in the S3DIS dataset's [config](../_base_/datasets/s3dis_seg-3d-13class.py) to set the training and testing areas, respectively. 
+ +## Indeterminism + +Since DGCNN testing adopts sliding patch inference which involves random point sampling, and the test script uses fixed random seeds while the random seeds of validation in training are not fixed, the test results may be slightly different from the results reported above. diff --git a/configs/dgcnn/dgcnn_32x4_cosine_100e_s3dis_seg-3d-13class.py b/configs/dgcnn/dgcnn_32x4_cosine_100e_s3dis_seg-3d-13class.py new file mode 100644 index 0000000000..6f1b5822a2 --- /dev/null +++ b/configs/dgcnn/dgcnn_32x4_cosine_100e_s3dis_seg-3d-13class.py @@ -0,0 +1,24 @@ +_base_ = [ + '../_base_/datasets/s3dis_seg-3d-13class.py', '../_base_/models/dgcnn.py', + '../_base_/schedules/seg_cosine_100e.py', '../_base_/default_runtime.py' +] + +# data settings +data = dict(samples_per_gpu=32) +evaluation = dict(interval=2) + +# model settings +model = dict( + backbone=dict(in_channels=9), # [xyz, rgb, normalized_xyz] + decode_head=dict( + num_classes=13, ignore_index=13, + loss_decode=dict(class_weight=None)), # S3DIS doesn't use class_weight + test_cfg=dict( + num_points=4096, + block_size=1.0, + sample_rate=0.5, + use_normalized_coord=True, + batch_size=24)) + +# runtime settings +checkpoint_config = dict(interval=2) diff --git a/configs/dgcnn/metafile.yml b/configs/dgcnn/metafile.yml new file mode 100644 index 0000000000..87ff9156bc --- /dev/null +++ b/configs/dgcnn/metafile.yml @@ -0,0 +1,24 @@ +Collections: + - Name: DGCNN + Metadata: + Training Techniques: + - SGD + Training Resources: 4x Titan XP GPUs + Architecture: + - DGCNN + Paper: https://arxiv.org/abs/1801.07829 + README: configs/dgcnn/README.md + +Models: + - Name: dgcnn_32x4_cosine_100e_s3dis_seg-3d-13class.py + In Collection: DGCNN + Config: configs/dgcnn/dgcnn_32x4_cosine_100e_s3dis_seg-3d-13class.py + Metadata: + Training Data: S3DIS + Training Memory (GB): 13.3 + Results: + - Task: 3D Semantic Segmentation + Dataset: S3DIS + Metrics: + mIoU: 50.59 + Weights: 
https://download.openmmlab.com/mmdetection3d/v0.17.0_models/dgcnn/dgcnn_32x4_cosine_100e_s3dis_seg-3d-13class/area5/dgcnn_32x4_cosine_100e_s3dis_seg-3d-13class_20210730_235824-f277e0c5.pth diff --git a/docs/model_zoo.md b/docs/model_zoo.md index 9192855959..3ebae29c58 100644 --- a/docs/model_zoo.md +++ b/docs/model_zoo.md @@ -77,3 +77,7 @@ Please refer to [ImVoxelNet](https://github.com/open-mmlab/mmdetection3d/blob/ma ### PAConv Please refer to [PAConv](https://github.com/open-mmlab/mmdetection3d/blob/master/configs/paconv) for details. We provide PAConv baselines on S3DIS dataset. + +### DGCNN + +Please refer to [DGCNN](https://github.com/open-mmlab/mmdetection3d/blob/master/configs/dgcnn) for details. We provide DGCNN baselines on S3DIS dataset. diff --git a/docs_zh-CN/model_zoo.md b/docs_zh-CN/model_zoo.md index d897bf9dfd..7a7589e0cd 100644 --- a/docs_zh-CN/model_zoo.md +++ b/docs_zh-CN/model_zoo.md @@ -75,3 +75,11 @@ ### ImVoxelNet 请参考 [ImVoxelNet](https://github.com/open-mmlab/mmdetection3d/blob/master/configs/imvoxelnet) 获取更多细节,我们在 KITTI 数据集上给出了相应的结果。 + +### PAConv + +请参考 [PAConv](https://github.com/open-mmlab/mmdetection3d/blob/master/configs/paconv) 获取更多细节,我们在 S3DIS 数据集上给出了相应的结果. + +### DGCNN + +请参考 [DGCNN](https://github.com/open-mmlab/mmdetection3d/blob/master/configs/dgcnn) 获取更多细节,我们在 S3DIS 数据集上给出了相应的结果. diff --git a/mmdet3d/models/backbones/__init__.py b/mmdet3d/models/backbones/__init__.py index 0251a10456..26c432ccfc 100644 --- a/mmdet3d/models/backbones/__init__.py +++ b/mmdet3d/models/backbones/__init__.py @@ -1,5 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. 
from mmdet.models.backbones import SSDVGG, HRNet, ResNet, ResNetV1d, ResNeXt +from .dgcnn import DGCNNBackbone from .multi_backbone import MultiBackbone from .nostem_regnet import NoStemRegNet from .pointnet2_sa_msg import PointNet2SAMSG @@ -8,5 +9,6 @@ __all__ = [ 'ResNet', 'ResNetV1d', 'ResNeXt', 'SSDVGG', 'HRNet', 'NoStemRegNet', - 'SECOND', 'PointNet2SASSG', 'PointNet2SAMSG', 'MultiBackbone' + 'SECOND', 'DGCNNBackbone', 'PointNet2SASSG', 'PointNet2SAMSG', + 'MultiBackbone' ] diff --git a/mmdet3d/models/backbones/dgcnn.py b/mmdet3d/models/backbones/dgcnn.py new file mode 100644 index 0000000000..fe369890e0 --- /dev/null +++ b/mmdet3d/models/backbones/dgcnn.py @@ -0,0 +1,98 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmcv.runner import BaseModule, auto_fp16 +from torch import nn as nn + +from mmdet3d.ops import DGCNNFAModule, DGCNNGFModule +from mmdet.models import BACKBONES + + +@BACKBONES.register_module() +class DGCNNBackbone(BaseModule): + """Backbone network for DGCNN. + + Args: + in_channels (int): Input channels of point cloud. + num_samples (tuple[int], optional): The number of samples for knn or + ball query in each graph feature (GF) module. + Defaults to (20, 20, 20). + knn_modes (tuple[str], optional): Mode of KNN of each knn module. + Defaults to ('D-KNN', 'F-KNN', 'F-KNN'). + radius (tuple[float], optional): Sampling radii of each GF module. + Defaults to (None, None, None). + gf_channels (tuple[tuple[int]], optional): Out channels of each mlp in + GF module. Defaults to ((64, 64), (64, 64), (64, )). + fa_channels (tuple[int], optional): Out channels of each mlp in FA + module. Defaults to (1024, ). + act_cfg (dict, optional): Config of activation layer. + Defaults to dict(type='ReLU'). + init_cfg (dict, optional): Initialization config. + Defaults to None. 
+ """ + + def __init__(self, + in_channels, + num_samples=(20, 20, 20), + knn_modes=('D-KNN', 'F-KNN', 'F-KNN'), + radius=(None, None, None), + gf_channels=((64, 64), (64, 64), (64, )), + fa_channels=(1024, ), + act_cfg=dict(type='ReLU'), + init_cfg=None): + super().__init__(init_cfg=init_cfg) + self.num_gf = len(gf_channels) + + assert len(num_samples) == len(knn_modes) == len(radius) == len( + gf_channels), 'Num_samples, knn_modes, radius and gf_channels \ + should have the same length.' + + self.GF_modules = nn.ModuleList() + gf_in_channel = in_channels * 2 + skip_channel_list = [gf_in_channel] # input channel list + + for gf_index in range(self.num_gf): + cur_gf_mlps = list(gf_channels[gf_index]) + cur_gf_mlps = [gf_in_channel] + cur_gf_mlps + gf_out_channel = cur_gf_mlps[-1] + + self.GF_modules.append( + DGCNNGFModule( + mlp_channels=cur_gf_mlps, + num_sample=num_samples[gf_index], + knn_mode=knn_modes[gf_index], + radius=radius[gf_index], + act_cfg=act_cfg)) + skip_channel_list.append(gf_out_channel) + gf_in_channel = gf_out_channel * 2 + + fa_in_channel = sum(skip_channel_list[1:]) + cur_fa_mlps = list(fa_channels) + cur_fa_mlps = [fa_in_channel] + cur_fa_mlps + + self.FA_module = DGCNNFAModule( + mlp_channels=cur_fa_mlps, act_cfg=act_cfg) + + @auto_fp16(apply_to=('points', )) + def forward(self, points): + """Forward pass. + + Args: + points (torch.Tensor): point coordinates with features, + with shape (B, N, in_channels). + + Returns: + dict[str, list[torch.Tensor]]: Outputs after graph feature (GF) and + feature aggregation (FA) modules. + + - gf_points (list[torch.Tensor]): Outputs after each GF module. + - fa_points (torch.Tensor): Outputs after FA module. 
+ """ + gf_points = [points] + + for i in range(self.num_gf): + cur_points = self.GF_modules[i](gf_points[i]) + gf_points.append(cur_points) + + fa_points = self.FA_module(gf_points) + + out = dict(gf_points=gf_points, fa_points=fa_points) + return out diff --git a/mmdet3d/models/decode_heads/__init__.py b/mmdet3d/models/decode_heads/__init__.py index e17d91da0c..2e86c7c8a9 100644 --- a/mmdet3d/models/decode_heads/__init__.py +++ b/mmdet3d/models/decode_heads/__init__.py @@ -1,5 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. +from .dgcnn_head import DGCNNHead from .paconv_head import PAConvHead from .pointnet2_head import PointNet2Head -__all__ = ['PointNet2Head', 'PAConvHead'] +__all__ = ['PointNet2Head', 'DGCNNHead', 'PAConvHead'] diff --git a/mmdet3d/models/decode_heads/dgcnn_head.py b/mmdet3d/models/decode_heads/dgcnn_head.py new file mode 100644 index 0000000000..4d4e1887bc --- /dev/null +++ b/mmdet3d/models/decode_heads/dgcnn_head.py @@ -0,0 +1,67 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmcv.cnn.bricks import ConvModule + +from mmdet3d.ops import DGCNNFPModule +from mmdet.models import HEADS +from .decode_head import Base3DDecodeHead + + +@HEADS.register_module() +class DGCNNHead(Base3DDecodeHead): + r"""DGCNN decoder head. + + Decoder head used in `DGCNN <https://arxiv.org/abs/1801.07829>`_. + Refer to the + `reimplementation code <https://github.com/AnTao97/dgcnn.pytorch>`_. + + Args: + fp_channels (tuple[int], optional): Tuple of mlp channels in feature + propagation (FP) modules. Defaults to (1216, 512). 
+ """ + + def __init__(self, fp_channels=(1216, 512), **kwargs): + super(DGCNNHead, self).__init__(**kwargs) + + self.FP_module = DGCNNFPModule( + mlp_channels=fp_channels, act_cfg=self.act_cfg) + + # https://github.com/charlesq34/pointnet2/blob/master/models/pointnet2_sem_seg.py#L40 + self.pre_seg_conv = ConvModule( + fp_channels[-1], + self.channels, + kernel_size=1, + bias=False, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + + def _extract_input(self, feat_dict): + """Extract inputs from features dictionary. + + Args: + feat_dict (dict): Feature dict from backbone. + + Returns: + torch.Tensor: points for decoder. + """ + fa_points = feat_dict['fa_points'] + + return fa_points + + def forward(self, feat_dict): + """Forward pass. + + Args: + feat_dict (dict): Feature dict from backbone. + + Returns: + torch.Tensor: Segmentation map of shape [B, num_classes, N]. + """ + fa_points = self._extract_input(feat_dict) + + fp_points = self.FP_module(fa_points) + fp_points = fp_points.transpose(1, 2).contiguous() + output = self.pre_seg_conv(fp_points) + output = self.cls_seg(output) + + return output diff --git a/mmdet3d/ops/__init__.py b/mmdet3d/ops/__init__.py index 38e2ea7367..1dafa428ac 100644 --- a/mmdet3d/ops/__init__.py +++ b/mmdet3d/ops/__init__.py @@ -4,6 +4,7 @@ sigmoid_focal_loss) from .ball_query import ball_query +from .dgcnn_modules import DGCNNFAModule, DGCNNFPModule, DGCNNGFModule from .furthest_point_sample import (Points_Sampler, furthest_point_sample, furthest_point_sample_with_dist) from .gather_points import gather_points @@ -34,8 +35,9 @@ 'furthest_point_sample_with_dist', 'three_interpolate', 'three_nn', 'gather_points', 'grouping_operation', 'group_points', 'GroupAll', 'QueryAndGroup', 'PointSAModule', 'PointSAModuleMSG', 'PointFPModule', - 'points_in_boxes_all', 'get_compiler_version', 'assign_score_withk', - 'get_compiling_cuda_version', 'Points_Sampler', 'build_sa_module', - 'PAConv', 'PAConvCUDA', 
'PAConvSAModuleMSG', 'PAConvSAModule', - 'PAConvCUDASAModule', 'PAConvCUDASAModuleMSG' + 'DGCNNFPModule', 'DGCNNGFModule', 'DGCNNFAModule', 'points_in_boxes_all', + 'get_compiler_version', 'assign_score_withk', 'get_compiling_cuda_version', + 'Points_Sampler', 'build_sa_module', 'PAConv', 'PAConvCUDA', + 'PAConvSAModuleMSG', 'PAConvSAModule', 'PAConvCUDASAModule', + 'PAConvCUDASAModuleMSG' ] diff --git a/mmdet3d/ops/dgcnn_modules/__init__.py b/mmdet3d/ops/dgcnn_modules/__init__.py new file mode 100644 index 0000000000..67beb0907f --- /dev/null +++ b/mmdet3d/ops/dgcnn_modules/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .dgcnn_fa_module import DGCNNFAModule +from .dgcnn_fp_module import DGCNNFPModule +from .dgcnn_gf_module import DGCNNGFModule + +__all__ = ['DGCNNFAModule', 'DGCNNFPModule', 'DGCNNGFModule'] diff --git a/mmdet3d/ops/dgcnn_modules/dgcnn_fa_module.py b/mmdet3d/ops/dgcnn_modules/dgcnn_fa_module.py new file mode 100644 index 0000000000..b0975e691b --- /dev/null +++ b/mmdet3d/ops/dgcnn_modules/dgcnn_fa_module.py @@ -0,0 +1,68 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +from mmcv.cnn import ConvModule +from mmcv.runner import BaseModule, force_fp32 +from torch import nn as nn + + +class DGCNNFAModule(BaseModule): + """Point feature aggregation module used in DGCNN. + + Aggregate all the features of points. + + Args: + mlp_channels (list[int]): List of mlp channels. + norm_cfg (dict, optional): Type of normalization method. + Defaults to dict(type='BN1d'). + act_cfg (dict, optional): Type of activation method. + Defaults to dict(type='ReLU'). + init_cfg (dict, optional): Initialization config. Defaults to None. 
+ """ + + def __init__(self, + mlp_channels, + norm_cfg=dict(type='BN1d'), + act_cfg=dict(type='ReLU'), + init_cfg=None): + super().__init__(init_cfg=init_cfg) + self.fp16_enabled = False + self.mlps = nn.Sequential() + for i in range(len(mlp_channels) - 1): + self.mlps.add_module( + f'layer{i}', + ConvModule( + mlp_channels[i], + mlp_channels[i + 1], + kernel_size=(1, ), + stride=(1, ), + conv_cfg=dict(type='Conv1d'), + norm_cfg=norm_cfg, + act_cfg=act_cfg)) + + @force_fp32() + def forward(self, points): + """forward. + + Args: + points (List[Tensor]): tensor of the features to be aggregated. + + Returns: + Tensor: (B, N, M + C), M = mlp[-1] and C = channels of the concatenated input features, tensor of the output points. + """ + + if len(points) > 1: + new_points = torch.cat(points[1:], dim=-1) + new_points = new_points.transpose(1, 2).contiguous() # (B, C, N) + new_points_copy = new_points + + new_points = self.mlps(new_points) + + new_fa_points = new_points.max(dim=-1, keepdim=True)[0] + new_fa_points = new_fa_points.repeat(1, 1, new_points.shape[-1]) + + new_points = torch.cat([new_fa_points, new_points_copy], dim=1) + new_points = new_points.transpose(1, 2).contiguous() + else: + new_points = points + + return new_points diff --git a/mmdet3d/ops/dgcnn_modules/dgcnn_fp_module.py b/mmdet3d/ops/dgcnn_modules/dgcnn_fp_module.py new file mode 100644 index 0000000000..c871721bc1 --- /dev/null +++ b/mmdet3d/ops/dgcnn_modules/dgcnn_fp_module.py @@ -0,0 +1,59 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmcv.cnn import ConvModule +from mmcv.runner import BaseModule, force_fp32 +from torch import nn as nn + + +class DGCNNFPModule(BaseModule): + """Point feature propagation module used in DGCNN. + + Propagate the features from one set to another. + + Args: + mlp_channels (list[int]): List of mlp channels. + norm_cfg (dict, optional): Type of normalization method. + Defaults to dict(type='BN1d'). + act_cfg (dict, optional): Type of activation method. + Defaults to dict(type='ReLU'). 
+ init_cfg (dict, optional): Initialization config. Defaults to None. + """ + + def __init__(self, + mlp_channels, + norm_cfg=dict(type='BN1d'), + act_cfg=dict(type='ReLU'), + init_cfg=None): + super().__init__(init_cfg=init_cfg) + self.fp16_enabled = False + self.mlps = nn.Sequential() + for i in range(len(mlp_channels) - 1): + self.mlps.add_module( + f'layer{i}', + ConvModule( + mlp_channels[i], + mlp_channels[i + 1], + kernel_size=(1, ), + stride=(1, ), + conv_cfg=dict(type='Conv1d'), + norm_cfg=norm_cfg, + act_cfg=act_cfg)) + + @force_fp32() + def forward(self, points): + """forward. + + Args: + points (Tensor): (B, N, C) tensor of the input points. + + Returns: + Tensor: (B, N, M) M = mlp[-1], tensor of the new points. + """ + + if points is not None: + new_points = points.transpose(1, 2).contiguous() # (B, C, N) + new_points = self.mlps(new_points) + new_points = new_points.transpose(1, 2).contiguous() + else: + new_points = points + + return new_points diff --git a/mmdet3d/ops/dgcnn_modules/dgcnn_gf_module.py b/mmdet3d/ops/dgcnn_modules/dgcnn_gf_module.py new file mode 100644 index 0000000000..e317ccd086 --- /dev/null +++ b/mmdet3d/ops/dgcnn_modules/dgcnn_gf_module.py @@ -0,0 +1,222 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +from mmcv.cnn import ConvModule +from torch import nn as nn +from torch.nn import functional as F + +from ..group_points import GroupAll, QueryAndGroup, grouping_operation + + +class BaseDGCNNGFModule(nn.Module): + """Base module for point graph feature module used in DGCNN. + + Args: + radii (list[float]): List of radius in each knn or ball query. + sample_nums (list[int]): Number of samples in each knn or ball query. + mlp_channels (list[list[int]]): Specify of the dgcnn before + the global pooling for each graph feature module. + knn_modes (list[str], optional): Type of KNN method, valid mode + ['F-KNN', 'D-KNN'], Defaults to ['F-KNN']. + dilated_group (bool, optional): Whether to use dilated ball query. 
+ Defaults to False. + use_xyz (bool, optional): Whether to use xyz as point features. + Defaults to True. + pool_mode (str, optional): Type of pooling method. Defaults to 'max'. + normalize_xyz (bool, optional): If ball query, whether to normalize + local XYZ with radius. Defaults to False. + grouper_return_grouped_xyz (bool, optional): Whether to return grouped + xyz in `QueryAndGroup`. Defaults to False. + grouper_return_grouped_idx (bool, optional): Whether to return grouped + idx in `QueryAndGroup`. Defaults to False. + """ + + def __init__(self, + radii, + sample_nums, + mlp_channels, + knn_modes=['F-KNN'], + dilated_group=False, + use_xyz=True, + pool_mode='max', + normalize_xyz=False, + grouper_return_grouped_xyz=False, + grouper_return_grouped_idx=False): + super(BaseDGCNNGFModule, self).__init__() + + assert len(sample_nums) == len( + mlp_channels + ), 'Num_samples and mlp_channels should have the same length.' + assert pool_mode in ['max', 'avg' + ], "Pool_mode should be one of ['max', 'avg']." + assert isinstance(knn_modes, list) or isinstance( + knn_modes, tuple), 'The type of knn_modes should be list or tuple.' 
+ + if isinstance(mlp_channels, tuple): + mlp_channels = list(map(list, mlp_channels)) + self.mlp_channels = mlp_channels + + self.pool_mode = pool_mode + self.groupers = nn.ModuleList() + self.mlps = nn.ModuleList() + self.knn_modes = knn_modes + + for i in range(len(sample_nums)): + sample_num = sample_nums[i] + if sample_num is not None: + if self.knn_modes[i] == 'D-KNN': + grouper = QueryAndGroup( + radii[i], + sample_num, + use_xyz=use_xyz, + normalize_xyz=normalize_xyz, + return_grouped_xyz=grouper_return_grouped_xyz, + return_grouped_idx=True) + else: + grouper = QueryAndGroup( + radii[i], + sample_num, + use_xyz=use_xyz, + normalize_xyz=normalize_xyz, + return_grouped_xyz=grouper_return_grouped_xyz, + return_grouped_idx=grouper_return_grouped_idx) + else: + grouper = GroupAll(use_xyz) + self.groupers.append(grouper) + + def _pool_features(self, features): + """Perform feature aggregation using pooling operation. + + Args: + features (torch.Tensor): (B, C, N, K) + Features of locally grouped points before pooling. + + Returns: + torch.Tensor: (B, C, N) + Pooled features aggregating local information. + """ + if self.pool_mode == 'max': + # (B, C, N, 1) + new_features = F.max_pool2d( + features, kernel_size=[1, features.size(3)]) + elif self.pool_mode == 'avg': + # (B, C, N, 1) + new_features = F.avg_pool2d( + features, kernel_size=[1, features.size(3)]) + else: + raise NotImplementedError + + return new_features.squeeze(-1).contiguous() + + def forward(self, points): + """forward. + + Args: + points (Tensor): (B, N, C) input points. + + Returns: + Tensor: (B, N, C1) new points generated from the last graph + feature module. 
+ """ + new_points_list = [points] + + for i in range(len(self.groupers)): + + new_points = new_points_list[i] + new_points_trans = new_points.transpose( + 1, 2).contiguous() # (B, C, N) + + if self.knn_modes[i] == 'D-KNN': + # (B, N, C) -> (B, N, K) + idx = self.groupers[i](new_points[..., -3:].contiguous(), + new_points[..., -3:].contiguous())[-1] + + grouped_results = grouping_operation( + new_points_trans, idx) # (B, C, N) -> (B, C, N, K) + grouped_results -= new_points_trans.unsqueeze(-1) + else: + grouped_results = self.groupers[i]( + new_points, new_points) # (B, N, C) -> (B, C, N, K) + + new_points = new_points_trans.unsqueeze(-1).repeat( + 1, 1, 1, grouped_results.shape[-1]) + new_points = torch.cat([grouped_results, new_points], dim=1) + + # (B, mlp[-1], N, K) + new_points = self.mlps[i](new_points) + + # (B, mlp[-1], N) + new_points = self._pool_features(new_points) + new_points = new_points.transpose(1, 2).contiguous() + new_points_list.append(new_points) + + return new_points + + +class DGCNNGFModule(BaseDGCNNGFModule): + """Point graph feature module used in DGCNN. + + Args: + mlp_channels (list[int]): Specify of the dgcnn before + the global pooling for each graph feature module. + num_sample (int, optional): Number of samples in each knn or ball + query. Defaults to None. + knn_mode (str, optional): Type of KNN method, valid mode + ['F-KNN', 'D-KNN']. Defaults to 'F-KNN'. + radius (float, optional): Radius to group with. + Defaults to None. + dilated_group (bool, optional): Whether to use dilated ball query. + Defaults to False. + norm_cfg (dict, optional): Type of normalization method. + Defaults to dict(type='BN2d'). + act_cfg (dict, optional): Type of activation method. + Defaults to dict(type='ReLU'). + use_xyz (bool, optional): Whether to use xyz as point features. + Defaults to True. + pool_mode (str, optional): Type of pooling method. + Defaults to 'max'. 
+ normalize_xyz (bool, optional): If ball query, whether to normalize + local XYZ with radius. Defaults to False. + bias (bool | str, optional): If specified as `auto`, it will be decided + by the norm_cfg. Bias will be set as True if `norm_cfg` is None, + otherwise False. Defaults to 'auto'. + """ + + def __init__(self, + mlp_channels, + num_sample=None, + knn_mode='F-KNN', + radius=None, + dilated_group=False, + norm_cfg=dict(type='BN2d'), + act_cfg=dict(type='ReLU'), + use_xyz=True, + pool_mode='max', + normalize_xyz=False, + bias='auto'): + super(DGCNNGFModule, self).__init__( + mlp_channels=[mlp_channels], + sample_nums=[num_sample], + knn_modes=[knn_mode], + radii=[radius], + use_xyz=use_xyz, + pool_mode=pool_mode, + normalize_xyz=normalize_xyz, + dilated_group=dilated_group) + + for i in range(len(self.mlp_channels)): + mlp_channel = self.mlp_channels[i] + + mlp = nn.Sequential() + for i in range(len(mlp_channel) - 1): + mlp.add_module( + f'layer{i}', + ConvModule( + mlp_channel[i], + mlp_channel[i + 1], + kernel_size=(1, 1), + stride=(1, 1), + conv_cfg=dict(type='Conv2d'), + norm_cfg=norm_cfg, + act_cfg=act_cfg, + bias=bias)) + self.mlps.append(mlp) diff --git a/tests/test_models/test_backbones.py b/tests/test_models/test_backbones.py index 6a3d2cb422..5c9f5edfe8 100644 --- a/tests/test_models/test_backbones.py +++ b/tests/test_models/test_backbones.py @@ -297,3 +297,36 @@ def test_pointnet2_sa_msg(): assert sa_indices[2].shape == torch.Size([1, 256]) assert sa_indices[3].shape == torch.Size([1, 64]) assert sa_indices[4].shape == torch.Size([1, 16]) + + +def test_dgcnn_gf(): + if not torch.cuda.is_available(): + pytest.skip() + + # DGCNNGF used in segmentation + cfg = dict( + type='DGCNNBackbone', + in_channels=6, + num_samples=(20, 20, 20), + knn_modes=['D-KNN', 'F-KNN', 'F-KNN'], + radius=(None, None, None), + gf_channels=((64, 64), (64, 64), (64, )), + fa_channels=(1024, ), + act_cfg=dict(type='ReLU')) + + self = build_backbone(cfg) + self.cuda() + + xyz 
= np.fromfile('tests/data/sunrgbd/points/000001.bin', dtype=np.float32) + xyz = torch.from_numpy(xyz).view(1, -1, 6).cuda() # (B, N, 6) + # test forward + ret_dict = self(xyz) + gf_points = ret_dict['gf_points'] + fa_points = ret_dict['fa_points'] + + assert len(gf_points) == 4 + assert gf_points[0].shape == torch.Size([1, 100, 6]) + assert gf_points[1].shape == torch.Size([1, 100, 64]) + assert gf_points[2].shape == torch.Size([1, 100, 64]) + assert gf_points[3].shape == torch.Size([1, 100, 64]) + assert fa_points.shape == torch.Size([1, 100, 1216]) diff --git a/tests/test_models/test_common_modules/test_dgcnn_modules.py b/tests/test_models/test_common_modules/test_dgcnn_modules.py new file mode 100644 index 0000000000..031971b459 --- /dev/null +++ b/tests/test_models/test_common_modules/test_dgcnn_modules.py @@ -0,0 +1,92 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import numpy as np +import pytest +import torch + + +def test_dgcnn_gf_module(): + if not torch.cuda.is_available(): + pytest.skip() + from mmdet3d.ops import DGCNNGFModule + + self = DGCNNGFModule( + mlp_channels=[18, 64, 64], + num_sample=20, + knn_mod='D-KNN', + radius=None, + norm_cfg=dict(type='BN2d'), + act_cfg=dict(type='ReLU'), + pool_mod='max').cuda() + + assert self.mlps[0].layer0.conv.in_channels == 18 + assert self.mlps[0].layer0.conv.out_channels == 64 + + xyz = np.fromfile('tests/data/sunrgbd/points/000001.bin', np.float32) + + # (B, N, C) + xyz = torch.from_numpy(xyz).view(1, -1, 3).cuda() + points = xyz.repeat([1, 1, 3]) + + # test forward + new_points = self(points) + + assert new_points.shape == torch.Size([1, 200, 64]) + + # test F-KNN mod + self = DGCNNGFModule( + mlp_channels=[6, 64, 64], + num_sample=20, + knn_mod='F-KNN', + radius=None, + norm_cfg=dict(type='BN2d'), + act_cfg=dict(type='ReLU'), + pool_mod='max').cuda() + + # test forward + new_points = self(xyz) + assert new_points.shape == torch.Size([1, 200, 64]) + + # test ball query + self = DGCNNGFModule( + 
mlp_channels=[6, 64, 64], + num_sample=20, + knn_mod='F-KNN', + radius=0.2, + norm_cfg=dict(type='BN2d'), + act_cfg=dict(type='ReLU'), + pool_mod='max').cuda() + + +def test_dgcnn_fa_module(): + if not torch.cuda.is_available(): + pytest.skip() + from mmdet3d.ops import DGCNNFAModule + + self = DGCNNFAModule(mlp_channels=[24, 16]).cuda() + assert self.mlps.layer0.conv.in_channels == 24 + assert self.mlps.layer0.conv.out_channels == 16 + + points = [torch.rand(1, 200, 12).float().cuda() for _ in range(3)] + + fa_points = self(points) + assert fa_points.shape == torch.Size([1, 200, 40]) + + +def test_dgcnn_fp_module(): + if not torch.cuda.is_available(): + pytest.skip() + from mmdet3d.ops import DGCNNFPModule + + self = DGCNNFPModule(mlp_channels=[24, 16]).cuda() + assert self.mlps.layer0.conv.in_channels == 24 + assert self.mlps.layer0.conv.out_channels == 16 + + xyz = np.fromfile('tests/data/sunrgbd/points/000001.bin', + np.float32).reshape((-1, 6)) + + # (B, N, 3) + xyz = torch.from_numpy(xyz).view(1, -1, 3).cuda() + points = xyz.repeat([1, 1, 8]).cuda() + + fp_points = self(points) + assert fp_points.shape == torch.Size([1, 200, 16]) diff --git a/tests/test_models/test_heads/test_dgcnn_decode_head.py b/tests/test_models/test_heads/test_dgcnn_decode_head.py new file mode 100644 index 0000000000..6d1f149530 --- /dev/null +++ b/tests/test_models/test_heads/test_dgcnn_decode_head.py @@ -0,0 +1,68 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+import numpy as np
+import pytest
+import torch
+from mmcv.cnn.bricks import ConvModule
+
+from mmdet3d.models.builder import build_head
+
+
+def test_dgcnn_decode_head_loss():
+    if not torch.cuda.is_available():
+        pytest.skip('test requires GPU and torch+cuda')
+    dgcnn_decode_head_cfg = dict(
+        type='DGCNNHead',
+        fp_channels=(1024, 512),
+        channels=256,
+        num_classes=13,
+        dropout_ratio=0.5,
+        conv_cfg=dict(type='Conv1d'),
+        norm_cfg=dict(type='BN1d'),
+        act_cfg=dict(type='LeakyReLU', negative_slope=0.2),
+        loss_decode=dict(
+            type='CrossEntropyLoss',
+            use_sigmoid=False,
+            class_weight=None,
+            loss_weight=1.0),
+        ignore_index=13)
+
+    self = build_head(dgcnn_decode_head_cfg)
+    self.cuda()
+    assert isinstance(self.conv_seg, torch.nn.Conv1d)
+    assert self.conv_seg.in_channels == 256
+    assert self.conv_seg.out_channels == 13
+    assert self.conv_seg.kernel_size == (1, )
+    assert isinstance(self.pre_seg_conv, ConvModule)
+    assert isinstance(self.pre_seg_conv.conv, torch.nn.Conv1d)
+    assert self.pre_seg_conv.conv.in_channels == 512
+    assert self.pre_seg_conv.conv.out_channels == 256
+    assert self.pre_seg_conv.conv.kernel_size == (1, )
+    assert isinstance(self.pre_seg_conv.bn, torch.nn.BatchNorm1d)
+    assert self.pre_seg_conv.bn.num_features == 256
+
+    # test forward: logits come out as (B, num_classes, N)
+    fa_points = torch.rand(2, 4096, 1024).float().cuda()
+    input_dict = dict(fa_points=fa_points)
+    seg_logits = self(input_dict)
+    assert seg_logits.shape == torch.Size([2, 13, 4096])
+
+    # test loss on random labels drawn from the 13 valid classes
+    pts_semantic_mask = torch.randint(0, 13, (2, 4096)).long().cuda()
+    losses = self.losses(seg_logits, pts_semantic_mask)
+    assert losses['loss_sem_seg'].item() > 0
+
+    # test loss with ignore_index: every label is ignored, so loss is 0
+    ignore_index_mask = torch.ones_like(pts_semantic_mask) * 13
+    losses = self.losses(seg_logits, ignore_index_mask)
+    assert losses['loss_sem_seg'].item() == 0
+
+    # test loss with class_weight: rebuild the head with random weights
+    dgcnn_decode_head_cfg['loss_decode'] = dict(
+        type='CrossEntropyLoss',
+        use_sigmoid=False,
+        class_weight=np.random.rand(13),
+        loss_weight=1.0)
+    self = build_head(dgcnn_decode_head_cfg)
+    self.cuda()
+    losses = self.losses(seg_logits, pts_semantic_mask)
+    assert losses['loss_sem_seg'].item() > 0
diff --git a/tests/test_models/test_segmentors.py b/tests/test_models/test_segmentors.py
index faff3f9515..0974f9f507 100644
--- a/tests/test_models/test_segmentors.py
+++ b/tests/test_models/test_segmentors.py
@@ -304,3 +304,48 @@ def test_paconv_cuda_ssg():
     results = self.forward(return_loss=False, **data_dict)
     assert results[0]['semantic_mask'].shape == torch.Size([200])
     assert results[1]['semantic_mask'].shape == torch.Size([100])
+
+
+def test_dgcnn():
+    if not torch.cuda.is_available():
+        pytest.skip('test requires GPU and torch+cuda')
+
+    set_random_seed(0, True)
+    dgcnn_cfg = _get_segmentor_cfg(
+        'dgcnn/dgcnn_32x4_cosine_100e_s3dis_seg-3d-13class.py')
+    dgcnn_cfg.test_cfg.num_points = 32
+    self = build_segmentor(dgcnn_cfg).cuda()
+    points = [torch.rand(4096, 9).float().cuda() for _ in range(2)]
+    img_metas = [dict(), dict()]
+    gt_masks = [torch.randint(0, 13, (4096, )).long().cuda() for _ in range(2)]
+
+    # test forward_train on two random 4096-point clouds
+    losses = self.forward_train(points, img_metas, gt_masks)
+    assert losses['decode.loss_sem_seg'].item() >= 0
+
+    # test loss with ignore_index: all labels set to 13, loss must be 0
+    ignore_masks = [torch.ones_like(gt_masks[0]) * 13 for _ in range(2)]
+    losses = self.forward_train(points, img_metas, ignore_masks)
+    assert losses['decode.loss_sem_seg'].item() == 0
+
+    # test simple_test: one semantic_mask entry per input point
+    self.eval()
+    with torch.no_grad():
+        scene_points = [
+            torch.randn(500, 6).float().cuda() * 3.0,
+            torch.randn(200, 6).float().cuda() * 2.5
+        ]
+        results = self.simple_test(scene_points, img_metas)
+        assert results[0]['semantic_mask'].shape == torch.Size([500])
+        assert results[1]['semantic_mask'].shape == torch.Size([200])
+
+    # test aug_test: inputs carry an extra leading dim of size 2
+    with torch.no_grad():
+        scene_points = [
+            torch.randn(2, 500, 6).float().cuda() * 3.0,
+            torch.randn(2, 200, 6).float().cuda() * 2.5
+        ]
+        img_metas = [[dict(), dict()], [dict(), dict()]]
+        results = self.aug_test(scene_points, img_metas)
+        assert results[0]['semantic_mask'].shape == torch.Size([500])
+        assert results[1]['semantic_mask'].shape == torch.Size([200])