From dfcf5428cb378888486d607f46de6725a53474ee Mon Sep 17 00:00:00 2001 From: xizaoqu <45515569+xizaoqu@users.noreply.github.com> Date: Mon, 20 Mar 2023 09:54:36 +0800 Subject: [PATCH] [Feature] Cylinder3d segmentor (#2344) * update * add cylinder3d_backbone * add test segmentor * add cfg * add test backbone * rename test cylinder3d backbone * midway * update, pass validation * fix test * update cfg --- configs/_base_/models/cylinder3d.py | 41 ++ .../cylinder3d_4xb2_3x_semantickitti.py | 37 ++ mmdet3d/datasets/seg3d_dataset.py | 4 +- mmdet3d/models/backbones/__init__.py | 3 +- mmdet3d/models/backbones/cylinder3d.py | 110 +++++ .../models/decode_heads/cylinder3d_head.py | 6 +- mmdet3d/models/layers/sparse_block.py | 382 +++++++++++++++++- mmdet3d/models/segmentors/__init__.py | 3 +- mmdet3d/models/segmentors/cylinder3d.py | 142 +++++++ .../test_cylinder3d_backbone.py | 32 ++ .../test_segmentors/test_cylinder3d.py | 42 ++ 11 files changed, 790 insertions(+), 12 deletions(-) create mode 100644 configs/_base_/models/cylinder3d.py create mode 100644 configs/cylinder3d/cylinder3d_4xb2_3x_semantickitti.py create mode 100644 mmdet3d/models/backbones/cylinder3d.py create mode 100644 mmdet3d/models/segmentors/cylinder3d.py create mode 100644 tests/test_models/test_backbones/test_cylinder3d_backbone.py create mode 100644 tests/test_models/test_segmentors/test_cylinder3d.py diff --git a/configs/_base_/models/cylinder3d.py b/configs/_base_/models/cylinder3d.py new file mode 100644 index 0000000000..02e8323363 --- /dev/null +++ b/configs/_base_/models/cylinder3d.py @@ -0,0 +1,41 @@ +grid_shape = [480, 360, 32] +model = dict( + type='Cylinder3D', + data_preprocessor=dict( + type='Det3DDataPreprocessor', + voxel=True, + voxel_type='cylindrical', + voxel_layer=dict( + grid_shape=grid_shape, + point_cloud_range=[0, -3.14159265359, -4, 50, 3.14159265359, 2], + max_num_points=-1, + max_voxels=-1, + ), + ), + voxel_encoder=dict( + type='SegVFE', + feat_channels=[64, 128, 256, 256], + 
in_channels=6, + with_voxel_center=True, + feat_compression=16, + return_point_feats=False), + backbone=dict( + type='Asymm3DSpconv', + grid_size=grid_shape, + input_channels=16, + base_channels=32, + norm_cfg=dict(type='BN1d', eps=1e-5, momentum=0.1)), + decode_head=dict( + type='Cylinder3DHead', + channels=128, + num_classes=20, + loss_ce=dict( + type='mmdet.CrossEntropyLoss', + use_sigmoid=False, + class_weight=None, + loss_weight=1.0), + loss_lovasz=dict(type='LovaszLoss', loss_weight=1.0, reduction='none'), + ), + train_cfg=None, + test_cfg=dict(mode='whole'), +) diff --git a/configs/cylinder3d/cylinder3d_4xb2_3x_semantickitti.py b/configs/cylinder3d/cylinder3d_4xb2_3x_semantickitti.py new file mode 100644 index 0000000000..7237bab014 --- /dev/null +++ b/configs/cylinder3d/cylinder3d_4xb2_3x_semantickitti.py @@ -0,0 +1,37 @@ +_base_ = [ + '../_base_/datasets/semantickitti.py', '../_base_/models/cylinder3d.py', + '../_base_/default_runtime.py' +] + +# optimizer +# This schedule is mainly used by models on nuScenes dataset +lr = 0.001 +optim_wrapper = dict( + type='OptimWrapper', + optimizer=dict(type='AdamW', lr=lr, weight_decay=0.01)) + +train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=36, val_interval=1) +val_cfg = dict(type='ValLoop') +test_cfg = dict(type='TestLoop') + +# learning rate +param_scheduler = [ + dict( + type='LinearLR', start_factor=0.001, by_epoch=False, begin=0, + end=1000), + dict( + type='MultiStepLR', + begin=0, + end=36, + by_epoch=True, + milestones=[30], + gamma=0.1) +] + +# Default setting for scaling LR automatically +# - `enable` means enable scaling LR automatically +# or not by default. +# - `base_batch_size` = (8 GPUs) x (4 samples per GPU). 
+# auto_scale_lr = dict(enable=False, base_batch_size=32) + +default_hooks = dict(checkpoint=dict(type='CheckpointHook', interval=5)) diff --git a/mmdet3d/datasets/seg3d_dataset.py b/mmdet3d/datasets/seg3d_dataset.py index 803a1b4d2f..5fd8a35ac6 100644 --- a/mmdet3d/datasets/seg3d_dataset.py +++ b/mmdet3d/datasets/seg3d_dataset.py @@ -255,8 +255,8 @@ def parse_data_info(self, info: dict) -> dict: osp.join( self.data_prefix.get('pts', ''), info['lidar_points']['lidar_path']) - - info['num_pts_feats'] = info['lidar_points']['num_pts_feats'] + if 'num_pts_feats' in info['lidar_points']: + info['num_pts_feats'] = info['lidar_points']['num_pts_feats'] info['lidar_path'] = info['lidar_points']['lidar_path'] if self.modality['use_camera']: diff --git a/mmdet3d/models/backbones/__init__.py b/mmdet3d/models/backbones/__init__.py index 009a06947a..bd7dc04ad4 100644 --- a/mmdet3d/models/backbones/__init__.py +++ b/mmdet3d/models/backbones/__init__.py @@ -1,6 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. from mmdet.models.backbones import SSDVGG, HRNet, ResNet, ResNetV1d, ResNeXt +from .cylinder3d import Asymm3DSpconv from .dgcnn import DGCNNBackbone from .dla import DLANet from .mink_resnet import MinkResNet @@ -13,5 +14,5 @@ __all__ = [ 'ResNet', 'ResNetV1d', 'ResNeXt', 'SSDVGG', 'HRNet', 'NoStemRegNet', 'SECOND', 'DGCNNBackbone', 'PointNet2SASSG', 'PointNet2SAMSG', - 'MultiBackbone', 'DLANet', 'MinkResNet' + 'MultiBackbone', 'DLANet', 'MinkResNet', 'Asymm3DSpconv' ] diff --git a/mmdet3d/models/backbones/cylinder3d.py b/mmdet3d/models/backbones/cylinder3d.py new file mode 100644 index 0000000000..4d1440d0d0 --- /dev/null +++ b/mmdet3d/models/backbones/cylinder3d.py @@ -0,0 +1,110 @@ +# Copyright (c) OpenMMLab. All rights reserved. +r"""Modified from Cylinder3D. 
+ +Please refer to `Cylinder3D github page +<https://github.com/xinge008/Cylinder3D>`_ for details +""" + +from typing import List + +import numpy as np +import torch +from mmcv.ops import SparseConvTensor +from mmengine.model import BaseModule + +from mmdet3d.models.layers.sparse_block import (AsymmeDownBlock, AsymmeUpBlock, + AsymmResBlock, DDCMBlock) +from mmdet3d.registry import MODELS +from mmdet3d.utils import ConfigType + + +@MODELS.register_module() +class Asymm3DSpconv(BaseModule): + """Asymmetrical 3D convolution networks. + + Args: + grid_size (int): Size of voxel grids. + input_channels (int): Input channels of the block. + base_channels (int): Initial size of feature channels before + feeding into Encoder-Decoder structure. Defaults to 16. + backbone_depth (int): The depth of backbone. The backbone contains + downblocks and upblocks with the number of backbone_depth. + height_pooing (List[bool]): List indicating which downblocks perform + height pooling. + norm_cfg (:obj:`ConfigDict` or dict): Config dict for normalization + layer. Defaults to dict(type='BN1d', eps=1e-3, momentum=0.01). + init_cfg (dict, optional): Initialization config. + Defaults to None.
+ """ + + def __init__(self, + grid_size: int, + input_channels: int, + base_channels: int = 16, + backbone_depth: int = 4, + height_pooing: List[bool] = [True, True, False, False], + norm_cfg: ConfigType = dict( + type='BN1d', eps=1e-3, momentum=0.01), + init_cfg=None): + super().__init__(init_cfg=init_cfg) + + self.grid_size = grid_size + self.backbone_depth = backbone_depth + self.down_context = AsymmResBlock( + input_channels, base_channels, indice_key='pre', norm_cfg=norm_cfg) + + self.down_block_list = torch.nn.ModuleList() + self.up_block_list = torch.nn.ModuleList() + for i in range(self.backbone_depth): + self.down_block_list.append( + AsymmeDownBlock( + 2**i * base_channels, + 2**(i + 1) * base_channels, + height_pooling=height_pooing[i], + indice_key='down' + str(i), + norm_cfg=norm_cfg)) + if i == self.backbone_depth - 1: + self.up_block_list.append( + AsymmeUpBlock( + 2**(i + 1) * base_channels, + 2**(i + 1) * base_channels, + up_key='down' + str(i), + indice_key='up' + str(self.backbone_depth - 1 - i), + norm_cfg=norm_cfg)) + else: + self.up_block_list.append( + AsymmeUpBlock( + 2**(i + 2) * base_channels, + 2**(i + 1) * base_channels, + up_key='down' + str(i), + indice_key='up' + str(self.backbone_depth - 1 - i), + norm_cfg=norm_cfg)) + + self.ddcm = DDCMBlock( + 2 * base_channels, + 2 * base_channels, + indice_key='ddcm', + norm_cfg=norm_cfg) + + def forward(self, voxel_features: torch.Tensor, coors: torch.Tensor, + batch_size: int) -> SparseConvTensor: + """Forward pass.""" + coors = coors.int() + ret = SparseConvTensor(voxel_features, coors, np.array(self.grid_size), + batch_size) + ret = self.down_context(ret) + + down_skip_list = [] + down_pool = ret + for i in range(self.backbone_depth): + down_pool, down_skip = self.down_block_list[i](down_pool) + down_skip_list.append(down_skip) + + up = down_pool + for i in range(self.backbone_depth - 1, -1, -1): + up = self.up_block_list[i](up, down_skip_list[i]) + + ddcm = self.ddcm(up) + ddcm.features = 
torch.cat((ddcm.features, up.features), 1) + + return ddcm diff --git a/mmdet3d/models/decode_heads/cylinder3d_head.py b/mmdet3d/models/decode_heads/cylinder3d_head.py index dd13fd4dd3..26c621c5ba 100644 --- a/mmdet3d/models/decode_heads/cylinder3d_head.py +++ b/mmdet3d/models/decode_heads/cylinder3d_head.py @@ -39,7 +39,7 @@ class Cylinder3DHead(Base3DDecodeHead): conv_seg_kernel_size (int): The kernel size used in conv_seg. Defaults to 3. ignore_index (int): The label index to be ignored. When using masked - BCE loss, ignore_index should be set to None. Defaults to 0. + BCE loss, ignore_index should be set to None. Defaults to 19. init_cfg (dict or :obj:`ConfigDict` or list[dict or :obj:`ConfigDict`], optional): Initialization config dict. Defaults to None. """ @@ -59,7 +59,7 @@ def __init__(self, loss_lovasz: ConfigType = dict( type='LovaszLoss', loss_weight=1.0), conv_seg_kernel_size: int = 3, - ignore_index: int = 0, + ignore_index: int = 19, init_cfg: OptMultiConfig = None) -> None: super(Cylinder3DHead, self).__init__( channels=channels, @@ -116,8 +116,6 @@ def loss_by_feat(self, seg_logit: SparseConvTensor, loss = dict() loss['loss_ce'] = self.loss_ce( seg_logit_feat, seg_label, ignore_index=self.ignore_index) - seg_logit_feat = seg_logit_feat.permute(1, 0)[None, :, :, - None] # pseudo BCHW loss['loss_lovasz'] = self.loss_lovasz( seg_logit_feat, seg_label, ignore_index=self.ignore_index) diff --git a/mmdet3d/models/layers/sparse_block.py b/mmdet3d/models/layers/sparse_block.py index 14fc4deeda..cbfa8ae8fb 100644 --- a/mmdet3d/models/layers/sparse_block.py +++ b/mmdet3d/models/layers/sparse_block.py @@ -1,17 +1,20 @@ # Copyright (c) OpenMMLab. All rights reserved. 
-from typing import Tuple, Union +from typing import Optional, Tuple, Union -from mmcv.cnn import build_conv_layer, build_norm_layer +from mmcv.cnn import build_activation_layer, build_conv_layer, build_norm_layer from mmdet.models.backbones.resnet import BasicBlock, Bottleneck from torch import nn -from mmdet3d.utils import OptConfigType +from mmdet3d.utils import ConfigType, OptConfigType from .spconv import IS_SPCONV2_AVAILABLE if IS_SPCONV2_AVAILABLE: from spconv.pytorch import SparseConvTensor, SparseModule, SparseSequential else: - from mmcv.ops import SparseConvTensor, SparseModule, SparseSequential + from mmcv.ops import (SparseConvTensor, SparseModule, SparseSequential, + SparseConv3d, SparseInverseConv3d, SubMConv3d) + +from mmengine.model import BaseModule def replace_feature(out: SparseConvTensor, @@ -207,3 +210,374 @@ def make_sparse_convmodule( layers = SparseSequential(*layers) return layers + + +# The following module only supports spconv_v1 +class AsymmResBlock(BaseModule): + """Asymmetrical Residual Block. + + Args: + in_channels (int): Input channels of the block. + out_channels (int): Output channels of the block. + norm_cfg (:obj:`ConfigDict` or dict): Config dict for + normalization layer. + act_cfg (:obj:`ConfigDict` or dict): Config dict of activation layers. + Defaults to dict(type='LeakyReLU'). + indice_key (str, optional): Name of indice tables. Defaults to None. 
+ """ + + def __init__(self, + in_channels: int, + out_channels: int, + norm_cfg: ConfigType, + act_cfg: ConfigType = dict(type='LeakyReLU'), + indice_key: Optional[str] = None): + super().__init__() + + self.conv0_0 = SubMConv3d( + in_channels, + out_channels, + kernel_size=(1, 3, 3), + padding=1, + bias=False, + indice_key=indice_key + 'bef') + self.act0_0 = build_activation_layer(act_cfg) + self.bn0_0 = build_norm_layer(norm_cfg, out_channels)[1] + + self.conv0_1 = SubMConv3d( + out_channels, + out_channels, + kernel_size=(3, 1, 3), + padding=1, + bias=False, + indice_key=indice_key + 'bef') + self.act0_1 = build_activation_layer(act_cfg) + self.bn0_1 = build_norm_layer(norm_cfg, out_channels)[1] + + self.conv1_0 = SubMConv3d( + in_channels, + out_channels, + kernel_size=(3, 1, 3), + padding=1, + bias=False, + indice_key=indice_key + 'bef') + self.act1_0 = build_activation_layer(act_cfg) + self.bn1_0 = build_norm_layer(norm_cfg, out_channels)[1] + + self.conv1_1 = SubMConv3d( + out_channels, + out_channels, + kernel_size=(1, 3, 3), + padding=1, + bias=False, + indice_key=indice_key + 'bef') + self.act1_1 = build_activation_layer(act_cfg) + self.bn1_1 = build_norm_layer(norm_cfg, out_channels)[1] + + def forward(self, x: SparseConvTensor) -> SparseConvTensor: + """Forward pass.""" + shortcut = self.conv0_0(x) + + shortcut.features = self.act0_0(shortcut.features) + shortcut.features = self.bn0_0(shortcut.features) + + shortcut = self.conv0_1(shortcut) + shortcut.features = self.act0_1(shortcut.features) + shortcut.features = self.bn0_1(shortcut.features) + + res = self.conv1_0(x) + res.features = self.act1_0(res.features) + res.features = self.bn1_0(res.features) + + res = self.conv1_1(res) + res.features = self.act1_1(res.features) + res.features = self.bn1_1(res.features) + + res.features = res.features + shortcut.features + + return res + + +class AsymmeDownBlock(BaseModule): + """Asymmetrical DownSample Block. 
+ + Args: + in_channels (int): Input channels of the block. + out_channels (int): Output channels of the block. + norm_cfg (:obj:`ConfigDict` or dict): Config dict for + normalization layer. + act_cfg (:obj:`ConfigDict` or dict): Config dict of activation layers. + Defaults to dict(type='LeakyReLU'). + pooling (bool): Whether pooling features at the end of + block. Defaults: True. + height_pooling (bool): Whether pooling features at + the height dimension. Defaults: False. + indice_key (str, optional): Name of indice tables. Defaults to None. + """ + + def __init__(self, + in_channels: int, + out_channels: int, + norm_cfg: ConfigType, + act_cfg: ConfigType = dict(type='LeakyReLU'), + pooling: bool = True, + height_pooling: bool = False, + indice_key: Optional[str] = None): + super().__init__() + self.pooling = pooling + + self.conv0_0 = SubMConv3d( + in_channels, + out_channels, + kernel_size=(3, 1, 3), + padding=1, + bias=False, + indice_key=indice_key + 'bef') + self.act0_0 = build_activation_layer(act_cfg) + self.bn0_0 = build_norm_layer(norm_cfg, out_channels)[1] + + self.conv0_1 = SubMConv3d( + out_channels, + out_channels, + kernel_size=(1, 3, 3), + padding=1, + bias=False, + indice_key=indice_key + 'bef') + self.act0_1 = build_activation_layer(act_cfg) + self.bn0_1 = build_norm_layer(norm_cfg, out_channels)[1] + + self.conv1_0 = SubMConv3d( + in_channels, + out_channels, + kernel_size=(1, 3, 3), + padding=1, + bias=False, + indice_key=indice_key + 'bef') + self.act1_0 = build_activation_layer(act_cfg) + self.bn1_0 = build_norm_layer(norm_cfg, out_channels)[1] + + self.conv1_1 = SubMConv3d( + out_channels, + out_channels, + kernel_size=(3, 1, 3), + padding=1, + bias=False, + indice_key=indice_key + 'bef') + self.act1_1 = build_activation_layer(act_cfg) + self.bn1_1 = build_norm_layer(norm_cfg, out_channels)[1] + + if pooling: + if height_pooling: + self.pool = SparseConv3d( + out_channels, + out_channels, + kernel_size=3, + stride=2, + padding=1, + 
indice_key=indice_key, + bias=False) + else: + self.pool = SparseConv3d( + out_channels, + out_channels, + kernel_size=3, + stride=(2, 2, 1), + padding=1, + indice_key=indice_key, + bias=False) + + def forward(self, x: SparseConvTensor) -> SparseConvTensor: + """Forward pass.""" + shortcut = self.conv0_0(x) + shortcut.features = self.act0_0(shortcut.features) + shortcut.features = self.bn0_0(shortcut.features) + + shortcut = self.conv0_1(shortcut) + shortcut.features = self.act0_1(shortcut.features) + shortcut.features = self.bn0_1(shortcut.features) + + res = self.conv1_0(x) + res.features = self.act1_0(res.features) + res.features = self.bn1_0(res.features) + + res = self.conv1_1(res) + res.features = self.act1_1(res.features) + res.features = self.bn1_1(res.features) + + res.features = res.features + shortcut.features + + if self.pooling: + pooled_res = self.pool(res) + return pooled_res, res + else: + return res + + +class AsymmeUpBlock(BaseModule): + """Asymmetrical UpSample Block. + + Args: + in_channels (int): Input channels of the block. + out_channels (int): Output channels of the block. + norm_cfg (:obj:`ConfigDict` or dict): Config dict for + normalization layer. + act_cfg (:obj:`ConfigDict` or dict): Config dict of activation layers. + Defaults to dict(type='LeakyReLU'). + indice_key (str, optional): Name of indice tables. Defaults to None. + up_key (str, optional): Name of indice tables used in + SparseInverseConv3d. Defaults to None. 
+ """ + + def __init__(self, + in_channels: int, + out_channels: int, + norm_cfg: ConfigType, + act_cfg: ConfigType = dict(type='LeakyReLU'), + indice_key: Optional[str] = None, + up_key: Optional[str] = None): + super().__init__() + + self.trans_conv = SubMConv3d( + in_channels, + out_channels, + kernel_size=(3, 3, 3), + padding=1, + bias=False, + indice_key=indice_key + 'new_up') + self.trans_act = build_activation_layer(act_cfg) + self.trans_bn = build_norm_layer(norm_cfg, out_channels)[1] + + self.conv1 = SubMConv3d( + out_channels, + out_channels, + kernel_size=(1, 3, 3), + padding=1, + bias=False, + indice_key=indice_key) + self.act1 = build_activation_layer(act_cfg) + self.bn1 = build_norm_layer(norm_cfg, out_channels)[1] + + self.conv2 = SubMConv3d( + out_channels, + out_channels, + kernel_size=(3, 1, 3), + padding=1, + bias=False, + indice_key=indice_key) + self.act2 = build_activation_layer(act_cfg) + self.bn2 = build_norm_layer(norm_cfg, out_channels)[1] + + self.conv3 = SubMConv3d( + out_channels, + out_channels, + kernel_size=(3, 3, 3), + padding=1, + bias=False, + indice_key=indice_key) + self.act3 = build_activation_layer(act_cfg) + self.bn3 = build_norm_layer(norm_cfg, out_channels)[1] + + self.up_subm = SparseInverseConv3d( + out_channels, + out_channels, + kernel_size=3, + indice_key=up_key, + bias=False) + + def forward(self, x: SparseConvTensor, + skip: SparseConvTensor) -> SparseConvTensor: + """Forward pass.""" + x_trans = self.trans_conv(x) + x_trans.features = self.trans_act(x_trans.features) + x_trans.features = self.trans_bn(x_trans.features) + + # upsample + up = self.up_subm(x_trans) + + up.features = up.features + skip.features + + up = self.conv1(up) + up.features = self.act1(up.features) + up.features = self.bn1(up.features) + + up = self.conv2(up) + up.features = self.act2(up.features) + up.features = self.bn2(up.features) + + up = self.conv3(up) + up.features = self.act3(up.features) + up.features = self.bn3(up.features) + + return 
up + + +class DDCMBlock(BaseModule): + """Dimension-Decomposition based Context Modeling. + + Args: + in_channels (int): Input channels of the block. + out_channels (int): Output channels of the block. + norm_cfg (:obj:`ConfigDict` or dict): Config dict for + normalization layer. + act_cfg (:obj:`ConfigDict` or dict): Config dict of activation layers. + Defaults to dict(type='Sigmoid'). + indice_key (str, optional): Name of indice tables. Defaults to None. + """ + + def __init__(self, + in_channels: int, + out_channels: int, + norm_cfg: ConfigType, + act_cfg: ConfigType = dict(type='Sigmoid'), + indice_key: Optional[str] = None): + super().__init__() + + self.conv1 = SubMConv3d( + in_channels, + out_channels, + kernel_size=(3, 1, 1), + padding=1, + bias=False, + indice_key=indice_key) + self.bn1 = build_norm_layer(norm_cfg, out_channels)[1] + self.act1 = build_activation_layer(act_cfg) + + self.conv2 = SubMConv3d( + in_channels, + out_channels, + kernel_size=(1, 3, 1), + padding=1, + bias=False, + indice_key=indice_key) + self.bn2 = build_norm_layer(norm_cfg, out_channels)[1] + self.act2 = build_activation_layer(act_cfg) + + self.conv3 = SubMConv3d( + in_channels, + out_channels, + kernel_size=(1, 1, 3), + padding=1, + bias=False, + indice_key=indice_key) + self.bn3 = build_norm_layer(norm_cfg, out_channels)[1] + self.act3 = build_activation_layer(act_cfg) + + def forward(self, x: SparseConvTensor) -> SparseConvTensor: + """Forward pass.""" + shortcut = self.conv1(x) + shortcut.features = self.bn1(shortcut.features) + shortcut.features = self.act1(shortcut.features) + + shortcut2 = self.conv2(x) + shortcut2.features = self.bn2(shortcut2.features) + shortcut2.features = self.act2(shortcut2.features) + + shortcut3 = self.conv3(x) + shortcut3.features = self.bn3(shortcut3.features) + shortcut3.features = self.act3(shortcut3.features) + shortcut.features = shortcut.features + \ + shortcut2.features + shortcut3.features + + shortcut.features = shortcut.features * 
x.features + + return shortcut diff --git a/mmdet3d/models/segmentors/__init__.py b/mmdet3d/models/segmentors/__init__.py index 29fbc33e6a..bc458d311f 100644 --- a/mmdet3d/models/segmentors/__init__.py +++ b/mmdet3d/models/segmentors/__init__.py @@ -1,5 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. from .base import Base3DSegmentor +from .cylinder3d import Cylinder3D from .encoder_decoder import EncoderDecoder3D -__all__ = ['Base3DSegmentor', 'EncoderDecoder3D'] +__all__ = ['Base3DSegmentor', 'EncoderDecoder3D', 'Cylinder3D'] diff --git a/mmdet3d/models/segmentors/cylinder3d.py b/mmdet3d/models/segmentors/cylinder3d.py new file mode 100644 index 0000000000..b126607e6f --- /dev/null +++ b/mmdet3d/models/segmentors/cylinder3d.py @@ -0,0 +1,142 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Dict + +from torch import Tensor + +from mmdet3d.registry import MODELS +from mmdet3d.utils import ConfigType, OptConfigType, OptMultiConfig +from ...structures.det3d_data_sample import SampleList +from .encoder_decoder import EncoderDecoder3D + + +@MODELS.register_module() +class Cylinder3D(EncoderDecoder3D): + """`Cylindrical and Asymmetrical 3D Convolution Networks for LiDAR + Segmentation. + + <https://arxiv.org/abs/2011.10033>`_. + + Args: + voxel_encoder (dict or :obj:`ConfigDict`): The config for the + points2voxel encoder of segmentor. + backbone (dict or :obj:`ConfigDict`): The config for the backbone of + segmentor. + decode_head (dict or :obj:`ConfigDict`): The config for the decode + head of segmentor. + neck (dict or :obj:`ConfigDict`, optional): The config for the neck of + segmentor. Defaults to None. + auxiliary_head (dict or :obj:`ConfigDict` or List[dict or + :obj:`ConfigDict`], optional): The config for the auxiliary head of + segmentor. Defaults to None. + loss_regularization (dict or :obj:`ConfigDict` or List[dict or + :obj:`ConfigDict`], optional): The config for the regularization + loss. Defaults to None.
+ train_cfg (dict or :obj:`ConfigDict`, optional): The config for + training. Defaults to None. + test_cfg (dict or :obj:`ConfigDict`, optional): The config for testing. + Defaults to None. + data_preprocessor (dict or :obj:`ConfigDict`, optional): The + pre-process config of :class:`BaseDataPreprocessor`. + Defaults to None. + init_cfg (dict or :obj:`ConfigDict` or List[dict or :obj:`ConfigDict`], + optional): The weight initialized config for :class:`BaseModule`. + Defaults to None. + """ + + def __init__(self, + voxel_encoder: ConfigType, + backbone: ConfigType, + decode_head: ConfigType, + neck: OptConfigType = None, + auxiliary_head: OptConfigType = None, + loss_regularization: OptConfigType = None, + train_cfg: OptConfigType = None, + test_cfg: OptConfigType = None, + data_preprocessor: OptConfigType = None, + init_cfg: OptMultiConfig = None) -> None: + super(Cylinder3D, self).__init__( + backbone=backbone, + decode_head=decode_head, + neck=neck, + auxiliary_head=auxiliary_head, + loss_regularization=loss_regularization, + train_cfg=train_cfg, + test_cfg=test_cfg, + data_preprocessor=data_preprocessor, + init_cfg=init_cfg) + + self.voxel_encoder = MODELS.build(voxel_encoder) + + def extract_feat(self, batch_inputs: dict) -> Tensor: + """Extract features from points.""" + encoded_feats = self.voxel_encoder(batch_inputs['voxels']['voxels'], + batch_inputs['voxels']['coors']) + batch_inputs['voxels']['voxel_coors'] = encoded_feats[1] + x = self.backbone(encoded_feats[0], encoded_feats[1], + len(batch_inputs['points'])) + if self.with_neck: + x = self.neck(x) + return x + + def loss(self, batch_inputs_dict: dict, + batch_data_samples: SampleList) -> Dict[str, Tensor]: + """Calculate losses from a batch of inputs and data samples. + + Args: + batch_inputs_dict (dict): Input sample dict which + includes 'points' and 'imgs' keys. + + - points (List[Tensor]): Point cloud of each sample. + - imgs (Tensor, optional): Image tensor has shape (B, C, H, W). 
+ batch_data_samples (List[:obj:`Det3DDataSample`]): The det3d data + samples. It usually includes information such as `metainfo` and + `gt_pts_seg`. + + Returns: + Dict[str, Tensor]: A dictionary of loss components. + """ + + # extract features using backbone + x = self.extract_feat(batch_inputs_dict) + losses = dict() + loss_decode = self._decode_head_forward_train(x, batch_data_samples) + losses.update(loss_decode) + + return losses + + def predict(self, + batch_inputs_dict: dict, + batch_data_samples: SampleList, + rescale: bool = True) -> SampleList: + """Simple test with single scene. + + Args: + batch_inputs_dict (dict): Input sample dict which includes 'points' + and 'imgs' keys. + + - points (List[Tensor]): Point cloud of each sample. + - imgs (Tensor, optional): Image tensor has shape (B, C, H, W). + batch_data_samples (List[:obj:`Det3DDataSample`]): The det3d data + samples. It usually includes information such as `metainfo` and + `gt_pts_seg`. + rescale (bool): Whether transform to original number of points. + Will be used for voxelization based segmentors. + Defaults to True. + + Returns: + List[:obj:`Det3DDataSample`]: Segmentation results of the input + points. Each Det3DDataSample usually contains: + + - ``pred_pts_seg`` (PixelData): Prediction of 3D semantic + segmentation. 
+ """ + # 3D segmentation requires per-point prediction, so it's impossible + # to use down-sampling to get a batch of scenes with same num_points + # therefore, we only support testing one scene every time + x = self.extract_feat(batch_inputs_dict) + seg_pred_list = self.decode_head.predict(x, batch_inputs_dict, + batch_data_samples) + for i in range(len(seg_pred_list)): + seg_pred_list[i] = seg_pred_list[i].argmax(1).cpu() + + return self.postprocess_result(seg_pred_list, batch_data_samples) diff --git a/tests/test_models/test_backbones/test_cylinder3d_backbone.py b/tests/test_models/test_backbones/test_cylinder3d_backbone.py new file mode 100644 index 0000000000..ea6b3e7ba2 --- /dev/null +++ b/tests/test_models/test_backbones/test_cylinder3d_backbone.py @@ -0,0 +1,32 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import pytest +import torch + +from mmdet3d.registry import MODELS + + +def test_cylinder3d(): + if not torch.cuda.is_available(): + pytest.skip() + cfg = dict( + type='Asymm3DSpconv', + grid_size=[48, 32, 4], + input_channels=16, + base_channels=32, + norm_cfg=dict(type='BN1d', eps=1e-5, momentum=0.1)) + self = MODELS.build(cfg) + self.cuda() + + batch_size = 1 + coorx = torch.randint(0, 48, (50, 1)) + coory = torch.randint(0, 36, (50, 1)) + coorz = torch.randint(0, 4, (50, 1)) + coorbatch = torch.zeros(50, 1) + coors = torch.cat([coorbatch, coorx, coory, coorz], dim=1).cuda() + voxel_features = torch.rand(50, 16).cuda() + + # test forward + feature = self(voxel_features, coors, batch_size) + + assert feature.features.shape == (50, 128) + assert feature.indices.data.shape == (50, 4) diff --git a/tests/test_models/test_segmentors/test_cylinder3d.py b/tests/test_models/test_segmentors/test_cylinder3d.py new file mode 100644 index 0000000000..d036637e1f --- /dev/null +++ b/tests/test_models/test_segmentors/test_cylinder3d.py @@ -0,0 +1,42 @@ +import unittest + +import torch +from mmengine import DefaultScope + +from mmdet3d.registry import MODELS 
+from mmdet3d.testing import (create_detector_inputs, get_detector_cfg, + setup_seed) + + +class TestCylinder3D(unittest.TestCase): + + def test_cylinder3d(self): + import mmdet3d.models + + assert hasattr(mmdet3d.models, 'Cylinder3D') + DefaultScope.get_instance('test_cylinder3d', scope_name='mmdet3d') + setup_seed(0) + cylinder3d_cfg = get_detector_cfg( + 'cylinder3d/cylinder3d_4xb2_3x_semantickitti.py') + cylinder3d_cfg.decode_head['ignore_index'] = 1 + model = MODELS.build(cylinder3d_cfg) + num_gt_instance = 3 + packed_inputs = create_detector_inputs( + num_gt_instance=num_gt_instance, + num_classes=1, + with_pts_semantic_mask=True) + + if torch.cuda.is_available(): + model = model.cuda() + # test simple_test + with torch.no_grad(): + data = model.data_preprocessor(packed_inputs, True) + torch.cuda.empty_cache() + results = model.forward(**data, mode='predict') + self.assertEqual(len(results), 1) + self.assertIn('pts_semantic_mask', results[0].pred_pts_seg) + + losses = model.forward(**data, mode='loss') + + self.assertGreater(losses['decode.loss_ce'], 0) + self.assertGreater(losses['decode.loss_lovasz'], 0)