[Feature] Cylinder3d segmentor (#2344)
* update

* add cylinder3d_backbone

* add test segmentor

* add cfg

* add test backbone

* rename test cylinder3d backbone

* midway

* update, pass validation

* fix test

* update cfg
xizaoqu authored Mar 20, 2023
1 parent afa4479 commit dfcf542
Showing 11 changed files with 790 additions and 12 deletions.
41 changes: 41 additions & 0 deletions configs/_base_/models/cylinder3d.py
@@ -0,0 +1,41 @@
grid_shape = [480, 360, 32]
model = dict(
    type='Cylinder3D',
    data_preprocessor=dict(
        type='Det3DDataPreprocessor',
        voxel=True,
        voxel_type='cylindrical',
        voxel_layer=dict(
            grid_shape=grid_shape,
            point_cloud_range=[0, -3.14159265359, -4, 50, 3.14159265359, 2],
            max_num_points=-1,
            max_voxels=-1,
        ),
    ),
    voxel_encoder=dict(
        type='SegVFE',
        feat_channels=[64, 128, 256, 256],
        in_channels=6,
        with_voxel_center=True,
        feat_compression=16,
        return_point_feats=False),
    backbone=dict(
        type='Asymm3DSpconv',
        grid_size=grid_shape,
        input_channels=16,
        base_channels=32,
        norm_cfg=dict(type='BN1d', eps=1e-5, momentum=0.1)),
    decode_head=dict(
        type='Cylinder3DHead',
        channels=128,
        num_classes=20,
        loss_ce=dict(
            type='mmdet.CrossEntropyLoss',
            use_sigmoid=False,
            class_weight=None,
            loss_weight=1.0),
        loss_lovasz=dict(type='LovaszLoss', loss_weight=1.0, reduction='none'),
    ),
    train_cfg=None,
    test_cfg=dict(mode='whole'),
)
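
Editor's note, not part of the commit: a minimal sketch of how a base model config like the one above is typically composed and built with mmengine. The config path and the call to register_all_modules() are assumptions based on the mmdet3d 1.x layout.

# Illustrative only: build the Cylinder3D segmentor from the base config.
from mmengine.config import Config

from mmdet3d.registry import MODELS
from mmdet3d.utils import register_all_modules

register_all_modules()  # registers Cylinder3D, SegVFE, Asymm3DSpconv, etc.
cfg = Config.fromfile('configs/_base_/models/cylinder3d.py')
model = MODELS.build(cfg.model)  # full segmentor: preprocessor + VFE + backbone + head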
37 changes: 37 additions & 0 deletions configs/cylinder3d/cylinder3d_4xb2_3x_semantickitti.py
@@ -0,0 +1,37 @@
_base_ = [
    '../_base_/datasets/semantickitti.py', '../_base_/models/cylinder3d.py',
    '../_base_/default_runtime.py'
]

# optimizer
# This schedule is mainly used by models on nuScenes dataset
lr = 0.001
optim_wrapper = dict(
    type='OptimWrapper',
    optimizer=dict(type='AdamW', lr=lr, weight_decay=0.01))

train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=36, val_interval=1)
val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')

# learning rate
param_scheduler = [
    dict(
        type='LinearLR', start_factor=0.001, by_epoch=False, begin=0,
        end=1000),
    dict(
        type='MultiStepLR',
        begin=0,
        end=36,
        by_epoch=True,
        milestones=[30],
        gamma=0.1)
]

# Default setting for scaling LR automatically
# - `enable` means enable scaling LR automatically
#   or not by default.
# - `base_batch_size` = (8 GPUs) x (4 samples per GPU).
# auto_scale_lr = dict(enable=False, base_batch_size=32)

default_hooks = dict(checkpoint=dict(type='CheckpointHook', interval=5))
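
Editor's note, not part of the commit: a training config like this is normally consumed by the mmengine Runner. A minimal sketch, assuming the SemanticKITTI data described by the dataset base config is prepared locally and using a made-up work directory name:

# Illustrative only: launch the 36-epoch schedule defined above.
from mmengine.config import Config
from mmengine.runner import Runner

from mmdet3d.utils import register_all_modules

register_all_modules()
cfg = Config.fromfile('configs/cylinder3d/cylinder3d_4xb2_3x_semantickitti.py')
cfg.work_dir = 'work_dirs/cylinder3d_semantickitti'  # hypothetical output dir
runner = Runner.from_cfg(cfg)
runner.train()  # LR warmup for 1000 iters, then x0.1 decay at epoch 30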
4 changes: 2 additions & 2 deletions mmdet3d/datasets/seg3d_dataset.py
@@ -255,8 +255,8 @@ def parse_data_info(self, info: dict) -> dict:
                 osp.join(
                     self.data_prefix.get('pts', ''),
                     info['lidar_points']['lidar_path'])
-
-            info['num_pts_feats'] = info['lidar_points']['num_pts_feats']
+            if 'num_pts_feats' in info['lidar_points']:
+                info['num_pts_feats'] = info['lidar_points']['num_pts_feats']
             info['lidar_path'] = info['lidar_points']['lidar_path']
 
         if self.modality['use_camera']:
3 changes: 2 additions & 1 deletion mmdet3d/models/backbones/__init__.py
@@ -1,6 +1,7 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from mmdet.models.backbones import SSDVGG, HRNet, ResNet, ResNetV1d, ResNeXt
 
+from .cylinder3d import Asymm3DSpconv
 from .dgcnn import DGCNNBackbone
 from .dla import DLANet
 from .mink_resnet import MinkResNet
@@ -13,5 +14,5 @@
 __all__ = [
     'ResNet', 'ResNetV1d', 'ResNeXt', 'SSDVGG', 'HRNet', 'NoStemRegNet',
     'SECOND', 'DGCNNBackbone', 'PointNet2SASSG', 'PointNet2SAMSG',
-    'MultiBackbone', 'DLANet', 'MinkResNet'
+    'MultiBackbone', 'DLANet', 'MinkResNet', 'Asymm3DSpconv'
 ]
110 changes: 110 additions & 0 deletions mmdet3d/models/backbones/cylinder3d.py
@@ -0,0 +1,110 @@
# Copyright (c) OpenMMLab. All rights reserved.
r"""Modified from Cylinder3D.

Please refer to `Cylinder3D github page
<https://github.com/xinge008/Cylinder3D>`_ for details
"""

from typing import List

import numpy as np
import torch
from mmcv.ops import SparseConvTensor
from mmengine.model import BaseModule

from mmdet3d.models.layers.sparse_block import (AsymmeDownBlock, AsymmeUpBlock,
                                                AsymmResBlock, DDCMBlock)
from mmdet3d.registry import MODELS
from mmdet3d.utils import ConfigType


@MODELS.register_module()
class Asymm3DSpconv(BaseModule):
    """Asymmetrical 3D convolution networks.

    Args:
        grid_size (int): Size of voxel grids.
        input_channels (int): Input channels of the block.
        base_channels (int): Initial size of feature channels before
            feeding into Encoder-Decoder structure. Defaults to 16.
        backbone_depth (int): The depth of backbone. The backbone contains
            downblocks and upblocks with the number of backbone_depth.
        height_pooing (List[bool]): List indicating which downblocks perform
            height pooling.
        norm_cfg (:obj:`ConfigDict` or dict): Config dict for normalization
            layer. Defaults to dict(type='BN1d', eps=1e-3, momentum=0.01)).
        init_cfg (dict, optional): Initialization config.
            Defaults to None.
    """

    def __init__(self,
                 grid_size: int,
                 input_channels: int,
                 base_channels: int = 16,
                 backbone_depth: int = 4,
                 height_pooing: List[bool] = [True, True, False, False],
                 norm_cfg: ConfigType = dict(
                     type='BN1d', eps=1e-3, momentum=0.01),
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)

        self.grid_size = grid_size
        self.backbone_depth = backbone_depth
        self.down_context = AsymmResBlock(
            input_channels, base_channels, indice_key='pre', norm_cfg=norm_cfg)

        self.down_block_list = torch.nn.ModuleList()
        self.up_block_list = torch.nn.ModuleList()
        for i in range(self.backbone_depth):
            self.down_block_list.append(
                AsymmeDownBlock(
                    2**i * base_channels,
                    2**(i + 1) * base_channels,
                    height_pooling=height_pooing[i],
                    indice_key='down' + str(i),
                    norm_cfg=norm_cfg))
            if i == self.backbone_depth - 1:
                self.up_block_list.append(
                    AsymmeUpBlock(
                        2**(i + 1) * base_channels,
                        2**(i + 1) * base_channels,
                        up_key='down' + str(i),
                        indice_key='up' + str(self.backbone_depth - 1 - i),
                        norm_cfg=norm_cfg))
            else:
                self.up_block_list.append(
                    AsymmeUpBlock(
                        2**(i + 2) * base_channels,
                        2**(i + 1) * base_channels,
                        up_key='down' + str(i),
                        indice_key='up' + str(self.backbone_depth - 1 - i),
                        norm_cfg=norm_cfg))

        self.ddcm = DDCMBlock(
            2 * base_channels,
            2 * base_channels,
            indice_key='ddcm',
            norm_cfg=norm_cfg)

    def forward(self, voxel_features: torch.Tensor, coors: torch.Tensor,
                batch_size: int) -> SparseConvTensor:
        """Forward pass."""
        coors = coors.int()
        ret = SparseConvTensor(voxel_features, coors, np.array(self.grid_size),
                               batch_size)
        ret = self.down_context(ret)

        down_skip_list = []
        down_pool = ret
        for i in range(self.backbone_depth):
            down_pool, down_skip = self.down_block_list[i](down_pool)
            down_skip_list.append(down_skip)

        up = down_pool
        for i in range(self.backbone_depth - 1, -1, -1):
            up = self.up_block_list[i](up, down_skip_list[i])

        ddcm = self.ddcm(up)
        ddcm.features = torch.cat((ddcm.features, up.features), 1)

        return ddcm
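
Editor's note, not part of the commit: a hypothetical smoke test for the backbone above. It assumes a CUDA device and an mmcv build with sparse-convolution ops, and the voxel count and cylindrical coordinates are made up for illustration.

# Illustrative only: run random cylindrical voxels through Asymm3DSpconv.
import torch

from mmdet3d.models.backbones import Asymm3DSpconv

backbone = Asymm3DSpconv(
    grid_size=[480, 360, 32], input_channels=16, base_channels=32).cuda()

num_voxels = 1024
coors = torch.stack(
    [
        torch.zeros(num_voxels, dtype=torch.long),  # batch index
        torch.randint(0, 480, (num_voxels,)),       # radial bin
        torch.randint(0, 360, (num_voxels,)),       # angular bin
        torch.randint(0, 32, (num_voxels,)),        # height bin
    ],
    dim=1)
coors = torch.unique(coors, dim=0).int().cuda()     # drop duplicate voxels
voxel_features = torch.rand(coors.shape[0], 16).cuda()

out = backbone(voxel_features, coors, batch_size=1)
# forward() concatenates the DDCM output with the last decoder features,
# so out.features should have 4 * base_channels = 128 channels per voxel,
# matching channels=128 in the Cylinder3DHead config.
print(out.features.shape)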
6 changes: 2 additions & 4 deletions mmdet3d/models/decode_heads/cylinder3d_head.py
@@ -39,7 +39,7 @@ class Cylinder3DHead(Base3DDecodeHead):
         conv_seg_kernel_size (int): The kernel size used in conv_seg.
             Defaults to 3.
         ignore_index (int): The label index to be ignored. When using masked
-            BCE loss, ignore_index should be set to None. Defaults to 0.
+            BCE loss, ignore_index should be set to None. Defaults to 19.
         init_cfg (dict or :obj:`ConfigDict` or list[dict or :obj:`ConfigDict`],
             optional): Initialization config dict. Defaults to None.
     """
@@ -59,7 +59,7 @@ def __init__(self,
                  loss_lovasz: ConfigType = dict(
                      type='LovaszLoss', loss_weight=1.0),
                  conv_seg_kernel_size: int = 3,
-                 ignore_index: int = 0,
+                 ignore_index: int = 19,
                  init_cfg: OptMultiConfig = None) -> None:
         super(Cylinder3DHead, self).__init__(
             channels=channels,
@@ -116,8 +116,6 @@ def loss_by_feat(self, seg_logit: SparseConvTensor,
         loss = dict()
         loss['loss_ce'] = self.loss_ce(
             seg_logit_feat, seg_label, ignore_index=self.ignore_index)
-        seg_logit_feat = seg_logit_feat.permute(1, 0)[None, :, :,
-                                                      None]  # pseudo BCHW
         loss['loss_lovasz'] = self.loss_lovasz(
             seg_logit_feat, seg_label, ignore_index=self.ignore_index)
 