Skip to content

Commit

Permalink
[Feature] Add MinkUNet segmentor (#2294)
Browse files Browse the repository at this point in the history
* add cylindrical voxelization & voxel feature encoder

* add cylindrical voxelization & voxel feature encoder

* add voxel-wise label & voxelization UT

* fix vfe

* fix vfe UT

* rename voxel encoder & add more test case

* fix type hint

* temporarily refactor mmcv's voxelize and dynamic voxelization in mmdet3d for data_preprocessor

* _forward

* del checkpoints

* add if tp

* add predict

* fix vfe init bug & fix UT

* add grid_size & move voxelization code

* fix import bug

* keep radian to follow origin

* add doc string

* fix type hint

* add minkunet voxelization and loss function

* fix data

* init train

* fix sparsetensor typehint

* rename dir

* fix data config

* fix data config

* fix batch_size & replace dynamic_scatter

* fix conflicts 2

* fix conflicts on s_70

* Alignment of the original implementation

* rename config

* add worker_init_fn_hook

* remove test_config & worker hook

* add UT

* fix polarmix UT

* add seed for cr0p5

* format

* rename SemanticKittiDataset

* add palette & fix visual bug

* add palette & fix data info bug

* fix ut

* fix semantic_kitti ut

* fix docstring

* fix config name

* rename layer

* fix doc string

* fix review

* remove filter data

* fix coors typo

* fix ut

* pred in segmentor

* fix get voxel seg

* resolve comments
  • Loading branch information
sunjiahao1999 authored Mar 23, 2023
1 parent be2029d commit ee6cc04
Show file tree
Hide file tree
Showing 15 changed files with 648 additions and 3 deletions.
29 changes: 29 additions & 0 deletions configs/_base_/models/minkunet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# Base MinkUNet segmentor: data preprocessor (voxelization), sparse U-Net
# backbone and per-voxel classification head.
model = dict(
    type='MinkUNet',
    data_preprocessor=dict(
        type='Det3DDataPreprocessor',
        voxel=True,
        # Selects the 'minkunet' voxelization branch of the preprocessor.
        voxel_type='minkunet',
        voxel_layer=dict(
            max_num_points=-1,
            point_cloud_range=[-100, -100, -20, 100, 100, 20],
            voxel_size=[0.05, 0.05, 0.05],
            max_voxels=(-1, -1),
        ),
    ),
    backbone=dict(
        type='MinkUNetBackbone',
        in_channels=4,
        base_channels=32,
        # One entry per stage; both lists must have `num_stages` items.
        encoder_channels=[32, 64, 128, 256],
        decoder_channels=[256, 128, 96, 96],
        num_stages=4,
        init_cfg=None,
    ),
    decode_head=dict(
        type='MinkUNetHead',
        # Matches the last entry of `decoder_channels` above.
        channels=96,
        num_classes=19,
        dropout_ratio=0,
        loss_decode=dict(type='mmdet.CrossEntropyLoss', avg_non_ignore=True),
        ignore_index=19,
    ),
    train_cfg=dict(),
    test_cfg=dict(),
)
13 changes: 13 additions & 0 deletions configs/minkunet/minkunet_w16_8xb2-15e_semantickitti.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Width-16 variant: inherits everything from the width-32 config and only
# overrides the channel widths of the backbone and the decode head.
_base_ = ['./minkunet_w32_8xb2-15e_semantickitti.py']

model = dict(
    backbone=dict(
        base_channels=16,
        encoder_channels=[16, 32, 64, 128],
        decoder_channels=[128, 64, 48, 48],
    ),
    decode_head=dict(channels=48),
)

# NOTE: With the TorchSparse backend the model performance depends
# noticeably on the random seed; if no seed is specified, results can
# vary by about ± 1.5 mIoU between runs.
randomness = dict(seed=1588147245)
8 changes: 8 additions & 0 deletions configs/minkunet/minkunet_w20_8xb2-15e_semantickitti.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# Width-20 variant: inherits everything from the width-32 config and only
# overrides the channel widths of the backbone and the decode head.
_base_ = ['./minkunet_w32_8xb2-15e_semantickitti.py']

model = dict(
    backbone=dict(
        base_channels=20,
        encoder_channels=[20, 40, 81, 163],
        decoder_channels=[163, 81, 61, 61],
    ),
    decode_head=dict(channels=61),
)
54 changes: 54 additions & 0 deletions configs/minkunet/minkunet_w32_8xb2-15e_semantickitti.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# Width-32 MinkUNet on SemanticKITTI: dataset, model and runtime come from
# the base configs; this file sets the training pipeline and schedule.
_base_ = [
    '../_base_/datasets/semantickitti.py',
    '../_base_/models/minkunet.py',
    '../_base_/default_runtime.py',
]

train_pipeline = [
    dict(type='LoadPointsFromFile', coord_type='LIDAR', load_dim=4, use_dim=4),
    dict(
        type='LoadAnnotations3D',
        with_bbox_3d=False,
        with_label_3d=False,
        with_seg_3d=True,
        seg_3d_dtype='np.int32',
        seg_offset=2**16,
        dataset_type='semantickitti',
    ),
    dict(type='PointSegClassMapping'),
    # Rotation range spans a full turn (0 to ~2*pi); translation is disabled
    # by the all-zero std.
    dict(
        type='GlobalRotScaleTrans',
        rot_range=[0., 6.28318531],
        scale_ratio_range=[0.95, 1.05],
        translation_std=[0, 0, 0],
    ),
    dict(type='Pack3DDetInputs', keys=['points', 'pts_semantic_mask']),
]

train_dataloader = dict(
    sampler=dict(seed=0),
    dataset=dict(dataset=dict(pipeline=train_pipeline)),
)

lr = 0.24
optim_wrapper = dict(
    type='AmpOptimWrapper',
    loss_scale='dynamic',
    optimizer=dict(
        type='SGD',
        lr=lr,
        weight_decay=0.0001,
        momentum=0.9,
        nesterov=True,
    ),
)

param_scheduler = [
    # Iteration-based linear schedule over iterations [0, 125).
    dict(
        type='LinearLR',
        start_factor=0.008,
        by_epoch=False,
        begin=0,
        end=125,
    ),
    # Epoch-based cosine annealing over 15 epochs, converted to
    # iteration-based stepping.
    dict(
        type='CosineAnnealingLR',
        begin=0,
        T_max=15,
        by_epoch=True,
        eta_min=1e-5,
        convert_to_iter_based=True,
    ),
]

train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=15, val_interval=1)
val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')

default_hooks = dict(checkpoint=dict(type='CheckpointHook', interval=1))
randomness = dict(seed=0, deterministic=False, diff_rank_seed=True)
env_cfg = dict(cudnn_benchmark=True)
4 changes: 3 additions & 1 deletion mmdet3d/models/backbones/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from .dgcnn import DGCNNBackbone
from .dla import DLANet
from .mink_resnet import MinkResNet
from .minkunet_backbone import MinkUNetBackbone
from .multi_backbone import MultiBackbone
from .nostem_regnet import NoStemRegNet
from .pointnet2_sa_msg import PointNet2SAMSG
Expand All @@ -14,5 +15,6 @@
__all__ = [
'ResNet', 'ResNetV1d', 'ResNeXt', 'SSDVGG', 'HRNet', 'NoStemRegNet',
'SECOND', 'DGCNNBackbone', 'PointNet2SASSG', 'PointNet2SAMSG',
'MultiBackbone', 'DLANet', 'MinkResNet', 'Asymm3DSpconv'
'MultiBackbone', 'DLANet', 'MinkResNet', 'Asymm3DSpconv',
'MinkUNetBackbone'
]
121 changes: 121 additions & 0 deletions mmdet3d/models/backbones/minkunet_backbone.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List

from mmengine.model import BaseModule
from mmengine.registry import MODELS
from torch import Tensor, nn

from mmdet3d.models.layers import (TorchSparseConvModule,
TorchSparseResidualBlock)
from mmdet3d.models.layers.torchsparse import IS_TORCHSPARSE_AVAILABLE
from mmdet3d.utils import OptMultiConfig

if IS_TORCHSPARSE_AVAILABLE:
import torchsparse
from torchsparse.tensor import SparseTensor
else:
SparseTensor = None


@MODELS.register_module()
class MinkUNetBackbone(BaseModule):
    r"""MinkUNet backbone with TorchSparse backend.

    Refer to `implementation code <https://github.com/mit-han-lab/spvnas>`_.

    The network is a sparse U-Net: an input stem of two sparse convolutions,
    ``num_stages`` encoder stages (each a stride-2 convolution followed by
    two residual blocks) and ``num_stages`` decoder stages (each a stride-2
    transposed convolution, a concatenation with the matching encoder
    feature, and two residual blocks).

    Args:
        in_channels (int): Number of input voxel feature channels.
            Defaults to 4.
        base_channels (int): The input channels for first encoder layer.
            Defaults to 32.
        encoder_channels (List[int]): Convolutional channels of each encode
            layer. Defaults to [32, 64, 128, 256]. The list is not modified.
        decoder_channels (List[int]): Convolutional channels of each decode
            layer. Defaults to [256, 128, 96, 96]. The list is not modified.
        num_stages (int): Number of stages in encoder and decoder.
            Defaults to 4.
        init_cfg (dict or :obj:`ConfigDict` or List[dict or :obj:`ConfigDict`]
            , optional): Initialization config dict.
    """

    def __init__(self,
                 in_channels: int = 4,
                 base_channels: int = 32,
                 encoder_channels: List[int] = [32, 64, 128, 256],
                 decoder_channels: List[int] = [256, 128, 96, 96],
                 num_stages: int = 4,
                 init_cfg: OptMultiConfig = None) -> None:
        super().__init__(init_cfg)
        assert num_stages == len(encoder_channels) == len(decoder_channels)
        self.num_stages = num_stages
        self.conv_input = nn.Sequential(
            TorchSparseConvModule(in_channels, base_channels, kernel_size=3),
            TorchSparseConvModule(base_channels, base_channels, kernel_size=3))
        self.encoder = nn.ModuleList()
        self.decoder = nn.ModuleList()

        # Build new lists instead of calling ``insert`` on the arguments:
        # in-place insertion would mutate the shared default lists (and any
        # caller-owned list/config), so a second instantiation would see one
        # extra entry and fail the length assertion above.
        encoder_channels = [base_channels, *encoder_channels]
        decoder_channels = [encoder_channels[-1], *decoder_channels]
        for i in range(num_stages):
            # Encoder stage: stride-2 conv, then two residual blocks that
            # change the width from stage i to stage i + 1.
            self.encoder.append(
                nn.Sequential(
                    TorchSparseConvModule(
                        encoder_channels[i],
                        encoder_channels[i],
                        kernel_size=2,
                        stride=2),
                    TorchSparseResidualBlock(
                        encoder_channels[i],
                        encoder_channels[i + 1],
                        kernel_size=3),
                    TorchSparseResidualBlock(
                        encoder_channels[i + 1],
                        encoder_channels[i + 1],
                        kernel_size=3)))

            # Decoder stage: element 0 is the stride-2 transposed conv,
            # element 1 processes the concatenation of its output with the
            # skip feature, hence the summed input width.
            self.decoder.append(
                nn.ModuleList([
                    TorchSparseConvModule(
                        decoder_channels[i],
                        decoder_channels[i + 1],
                        kernel_size=2,
                        stride=2,
                        transposed=True),
                    nn.Sequential(
                        TorchSparseResidualBlock(
                            decoder_channels[i + 1] + encoder_channels[-2 - i],
                            decoder_channels[i + 1],
                            kernel_size=3),
                        TorchSparseResidualBlock(
                            decoder_channels[i + 1],
                            decoder_channels[i + 1],
                            kernel_size=3))
                ]))

    def forward(self, voxel_features: Tensor, coors: Tensor) -> SparseTensor:
        """Forward function.

        Args:
            voxel_features (Tensor): Voxel features in shape (N, C).
            coors (Tensor): Coordinates in shape (N, 4),
                the columns in the order of (x_idx, y_idx, z_idx, batch_idx).

        Returns:
            SparseTensor: Backbone features, i.e. the output of the last
            decoder stage.
        """
        x = torchsparse.SparseTensor(voxel_features, coors)
        x = self.conv_input(x)
        laterals = [x]
        for encoder_layer in self.encoder:
            x = encoder_layer(x)
            laterals.append(x)
        # Drop the deepest encoder output (it is the decoder input itself)
        # and reverse so that laterals[i] is the skip for decoder stage i.
        laterals = laterals[:-1][::-1]

        decoder_outs = []
        for i, decoder_layer in enumerate(self.decoder):
            x = decoder_layer[0](x)
            x = torchsparse.cat((x, laterals[i]))
            x = decoder_layer[1](x)
            decoder_outs.append(x)

        return decoder_outs[-1]
77 changes: 77 additions & 0 deletions mmdet3d/models/data_preprocessors/data_preprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,6 +415,33 @@ def voxelize(self, points: List[torch.Tensor],
coors.append(res_coors)
voxels = torch.cat(voxels, dim=0)
coors = torch.cat(coors, dim=0)
elif self.voxel_type == 'minkunet':
voxels, coors = [], []
voxel_size = points[0].new_tensor(self.voxel_layer.voxel_size)
for i, (res, data_sample) in enumerate(zip(points, data_samples)):
res_coors = torch.round(res[:, :3] / voxel_size).int()
res_coors -= res_coors.min(0)[0]

res_coors_numpy = res_coors.cpu().numpy()
inds, voxel2point_map = self.sparse_quantize(
res_coors_numpy, return_index=True, return_inverse=True)
voxel2point_map = torch.from_numpy(voxel2point_map).cuda()
if self.training:
if len(inds) > 80000:
inds = np.random.choice(inds, 80000, replace=False)
inds = torch.from_numpy(inds).cuda()
data_sample.gt_pts_seg.voxel_semantic_mask \
= data_sample.gt_pts_seg.pts_semantic_mask[inds]
res_voxel_coors = res_coors[inds]
res_voxels = res[inds]
res_voxel_coors = F.pad(
res_voxel_coors, (0, 1), mode='constant', value=i)
data_sample.voxel2point_map = voxel2point_map.long()
voxels.append(res_voxels)
coors.append(res_voxel_coors)
voxels = torch.cat(voxels, dim=0)
coors = torch.cat(coors, dim=0)

else:
raise ValueError(f'Invalid voxelization type {self.voxel_type}')

Expand Down Expand Up @@ -445,3 +472,53 @@ def get_voxel_seg(self, res_coors: torch.Tensor, data_sample: SampleList):
_, _, point2voxel_map = dynamic_scatter_3d(pseudo_tensor,
res_coors, 'mean', True)
data_sample.gt_pts_seg.point2voxel_map = point2voxel_map

def ravel_hash(self, x: np.ndarray) -> np.ndarray:
    """Collapse each row of voxel coordinates into one integer key.

    The coordinates are shifted so that every column starts at zero, then
    combined column by column in a mixed-radix fashion, using the extent
    (max + 1) of each column as its radix. The resulting keys can be fed
    to ``np.unique()`` to find duplicated voxels.

    Args:
        x (np.ndarray): The voxel coordinates of points, Nx3.

    Returns:
        np.ndarray: Voxels coordinates hash, one ``uint64`` per row.
    """
    assert x.ndim == 2, x.shape

    shifted = (x - np.min(x, axis=0)).astype(np.uint64, copy=False)
    extents = np.max(shifted, axis=0).astype(np.uint64) + 1

    # Horner-style accumulation: start from the first column and, for each
    # following column, scale by that column's extent and add its value.
    hashed = shifted[:, 0].copy()
    for col in range(1, shifted.shape[1]):
        hashed *= extents[col]
        hashed += shifted[:, col]
    return hashed

def sparse_quantize(self,
                    coords: np.ndarray,
                    return_index: bool = False,
                    return_inverse: bool = False) -> List[np.ndarray]:
    """Sparse Quantization for voxel coordinates used in Minkunet.

    Points whose voxel coordinates hash to the same key (see
    ``ravel_hash``) are treated as belonging to the same voxel.

    Args:
        coords (np.ndarray): The voxel coordinates of points, Nx3.
        return_index (bool): Whether to return the indices of the
            unique coords, shape (M,).
        return_inverse (bool): Whether to return the indices of the
            original coords shape (N,).

    Returns:
        List[np.ndarray]: The requested arrays, in the order
        ``[indices, inverse_indices]``. An array is included only when
        its flag is True, so the list is empty when both flags are False.
    """
    # Both index arrays are always computed; the flags only select which
    # of them are handed back. The unique coordinates themselves are not
    # part of the result, so they are not gathered here.
    _, indices, inverse_indices = np.unique(
        self.ravel_hash(coords), return_index=True, return_inverse=True)

    outputs = []
    if return_index:
        outputs.append(indices)
    if return_inverse:
        outputs.append(inverse_indices)
    return outputs
6 changes: 5 additions & 1 deletion mmdet3d/models/decode_heads/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .cylinder3d_head import Cylinder3DHead
from .dgcnn_head import DGCNNHead
from .minkunet_head import MinkUNetHead
from .paconv_head import PAConvHead
from .pointnet2_head import PointNet2Head

__all__ = ['PointNet2Head', 'DGCNNHead', 'PAConvHead', 'Cylinder3DHead']
__all__ = [
'PointNet2Head', 'DGCNNHead', 'PAConvHead', 'Cylinder3DHead',
'MinkUNetHead'
]
Loading

0 comments on commit ee6cc04

Please sign in to comment.