Skip to content

Commit

Permalink
[Feature] Support PolarMix augmentation (#2265)
Browse files Browse the repository at this point in the history
* support polarmix

* Update __init__.py

* add UT

* use `BasePoints` instead of numpy

* Update transforms_3d.py

* Update transforms_3d.py

* Update test_transforms_3d.py

* update docs

* update polarmix without MultiImageMixDataset

* add comments

* fix UT

* update docstring

* fix yaw calculation

* fix UT

* refactor

* update

* update docs

* fix typo

* Update transforms_3d.py

* update ut

* fix typehint

* add prob argument
  • Loading branch information
Xiangxu-0103 authored Feb 23, 2023
1 parent 21de1af commit 60d848b
Show file tree
Hide file tree
Showing 4 changed files with 316 additions and 6 deletions.
18 changes: 18 additions & 0 deletions mmdet3d/datasets/seg3d_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,24 @@ def parse_data_info(self, info: dict) -> dict:

return info

def prepare_data(self, idx: int) -> dict:
    """Fetch one sample and run it through ``self.pipeline``.

    Args:
        idx (int): The index of ``data_info``.

    Returns:
        dict: Results passed through ``self.pipeline``.
    """
    if self.test_mode:
        # Testing: defer entirely to the parent implementation.
        return super().prepare_data(idx)

    # Training: expose the dataset itself to the pipeline so that mixing
    # augmentations (such as polarmix) can draw extra samples from it.
    sample = self.get_data_info(idx)
    sample['dataset'] = self
    return self.pipeline(sample)

def get_scene_idxs(self, scene_idxs: Union[None, str,
np.ndarray]) -> np.ndarray:
"""Compute scene_idxs for data sampling.
Expand Down
4 changes: 2 additions & 2 deletions mmdet3d/datasets/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
MultiViewWrapper, ObjectNameFilter, ObjectNoise,
ObjectRangeFilter, ObjectSample,
PhotoMetricDistortion3D, PointSample, PointShuffle,
PointsRangeFilter, RandomDropPointsColor,
PointsRangeFilter, PolarMix, RandomDropPointsColor,
RandomFlip3D, RandomJitterPoints, RandomResize3D,
RandomShiftScale, Resize3D, VoxelBasedPointSampler)

Expand All @@ -30,5 +30,5 @@
'RandomDropPointsColor', 'RandomJitterPoints', 'AffineResize',
'RandomShiftScale', 'LoadPointsFromDict', 'Resize3D', 'RandomResize3D',
'MultiViewWrapper', 'PhotoMetricDistortion3D', 'MonoDet3DInferencerLoader',
'LidarDet3DInferencerLoader'
'LidarDet3DInferencerLoader', 'PolarMix'
]
173 changes: 171 additions & 2 deletions mmdet3d/datasets/transforms/transforms_3d.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
# Copyright (c) OpenMMLab. All rights reserved.
import random
import warnings
from typing import List, Optional, Tuple, Union
from typing import List, Optional, Sequence, Tuple, Union

import cv2
import mmcv
import numpy as np
import torch
from mmcv.transforms import BaseTransform, Compose, RandomResize, Resize
from mmdet.datasets.transforms import (PhotoMetricDistortion, RandomCrop,
RandomFlip)
from mmengine import is_tuple_of
from mmengine import is_list_of, is_tuple_of

from mmdet3d.models.task_modules import VoxelGenerator
from mmdet3d.registry import TRANSFORMS
Expand Down Expand Up @@ -2352,3 +2353,171 @@ def transform(self, input_dict: dict) -> dict:
if len(input_dict[key]) == 0:
input_dict.pop(key)
return input_dict


@TRANSFORMS.register_module()
class PolarMix(BaseTransform):
    """PolarMix data augmentation.

    The polarmix transform steps are as follows:

        1. Another random point cloud is picked by dataset.
        2. Exchange sectors of two point clouds that are cut with certain
           azimuth angles.
        3. Cut point instances from picked point cloud, rotate them by
           multiple azimuth angles, and paste the cut and rotated instances.

    Required Keys:

    - points (:obj:`BasePoints`)
    - pts_semantic_mask (np.int64)
    - dataset (:obj:`BaseDataset`)

    Modified Keys:

    - points (:obj:`BasePoints`)
    - pts_semantic_mask (np.int64)

    Args:
        instance_classes (List[int]): Semantic masks which represent the
            instance. An empty list disables the rotate-pasting step.
        swap_ratio (float): Probability of performing the sector swap
            between the two point clouds. Defaults to 0.5.
        rotate_paste_ratio (float): Probability of performing the
            rotate-pasting step. Defaults to 1.0.
        pre_transform (Sequence[dict], optional): Sequence of transform object
            or config dict to be composed. It is applied to the picked sample
            before mixing. Defaults to None.
        prob (float): The transformation probability. Defaults to 1.0.
    """

    def __init__(self,
                 instance_classes: List[int],
                 swap_ratio: float = 0.5,
                 rotate_paste_ratio: float = 1.0,
                 pre_transform: Optional[Sequence[dict]] = None,
                 prob: float = 1.0) -> None:
        # NOTE: kept as an assertion on purpose; the unit tests expect an
        # ``AssertionError`` for invalid ``instance_classes``.
        assert is_list_of(instance_classes, int), \
            'instance_classes should be a list of int'
        self.instance_classes = instance_classes
        self.swap_ratio = swap_ratio
        self.rotate_paste_ratio = rotate_paste_ratio

        self.prob = prob
        if pre_transform is None:
            self.pre_transform = None
        else:
            self.pre_transform = Compose(pre_transform)

    def polar_mix_transform(self, input_dict: dict, mix_results: dict) -> dict:
        """PolarMix transform function.

        Args:
            input_dict (dict): Result dict from loading pipeline.
            mix_results (dict): Mixed dict picked from dataset.

        Returns:
            dict: output dict after transformation.
        """
        mix_points = mix_results['points']
        mix_pts_semantic_mask = mix_results['pts_semantic_mask']

        points = input_dict['points']
        pts_semantic_mask = input_dict['pts_semantic_mask']

        # 1. swap point cloud
        if np.random.random() < self.swap_ratio:
            # The sector spans half a turn: [start_angle, start_angle + pi].
            start_angle = (np.random.random() - 1) * np.pi  # -pi~0
            end_angle = start_angle + np.pi
            # calculate horizontal angle for each point
            # (negated atan2 of the y and x coordinates)
            yaw = -torch.atan2(points.coord[:, 1], points.coord[:, 0])
            mix_yaw = -torch.atan2(mix_points.coord[:, 1],
                                   mix_points.coord[:, 0])

            # select points in sector: keep the original points that lie
            # outside of it and take the mixed points that lie inside of it
            idx = (yaw <= start_angle) | (yaw >= end_angle)
            mix_idx = (mix_yaw > start_angle) & (mix_yaw < end_angle)

            # swap
            points = points.cat([points[idx], mix_points[mix_idx]])
            # the torch boolean masks are converted to numpy to index the
            # numpy semantic masks
            pts_semantic_mask = np.concatenate(
                (pts_semantic_mask[idx.numpy()],
                 mix_pts_semantic_mask[mix_idx.numpy()]),
                axis=0)

        # 2. rotate-pasting
        # Skipped when no instance class is configured: there is nothing to
        # cut, and concatenating an empty list of arrays would raise.
        if self.instance_classes and \
                np.random.random() < self.rotate_paste_ratio:
            # extract instance points
            instance_points, instance_pts_semantic_mask = [], []
            for instance_class in self.instance_classes:
                mix_idx = mix_pts_semantic_mask == instance_class
                instance_points.append(mix_points[mix_idx])
                instance_pts_semantic_mask.append(
                    mix_pts_semantic_mask[mix_idx])
            instance_points = mix_points.cat(instance_points)
            instance_pts_semantic_mask = np.concatenate(
                instance_pts_semantic_mask, axis=0)

            # rotate-copy: the unrotated instances plus two rotated copies,
            # one with an angle in [0, 2*pi/3) and one in [2*pi/3, 4*pi/3)
            copy_points = [instance_points]
            copy_pts_semantic_mask = [instance_pts_semantic_mask]
            angle_list = [
                np.random.random() * np.pi * 2 / 3,
                (np.random.random() + 1) * np.pi * 2 / 3
            ]
            for angle in angle_list:
                new_points = instance_points.clone()
                # the return value is discarded, so this relies on
                # ``rotate`` modifying ``new_points`` in place
                new_points.rotate(angle)
                copy_points.append(new_points)
                copy_pts_semantic_mask.append(instance_pts_semantic_mask)
            copy_points = instance_points.cat(copy_points)
            copy_pts_semantic_mask = np.concatenate(
                copy_pts_semantic_mask, axis=0)

            points = points.cat([points, copy_points])
            pts_semantic_mask = np.concatenate(
                (pts_semantic_mask, copy_pts_semantic_mask), axis=0)

        input_dict['points'] = points
        input_dict['pts_semantic_mask'] = pts_semantic_mask
        return input_dict

    def transform(self, input_dict: dict) -> dict:
        """PolarMix transform function.

        Args:
            input_dict (dict): Result dict from loading pipeline.

        Returns:
            dict: output dict after transformation.
        """
        # Leave the sample untouched with probability ``1 - prob``.
        if np.random.rand() > self.prob:
            return input_dict

        assert 'dataset' in input_dict, \
            '`dataset` is needed to pass through PolarMix, while not found.'
        dataset = input_dict['dataset']

        # get index of other point cloud
        index = np.random.randint(0, len(dataset))

        mix_results = dataset.get_data_info(index)

        if self.pre_transform is not None:
            # pre_transform may also require dataset
            mix_results.update({'dataset': dataset})
            # before polarmix need to go through
            # the necessary pre_transform
            mix_results = self.pre_transform(mix_results)
            mix_results.pop('dataset')

        input_dict = self.polar_mix_transform(input_dict, mix_results)

        return input_dict

    def __repr__(self) -> str:
        """str: Return a string that describes the module."""
        repr_str = self.__class__.__name__
        repr_str += f'(instance_classes={self.instance_classes}, '
        repr_str += f'swap_ratio={self.swap_ratio}, '
        repr_str += f'rotate_paste_ratio={self.rotate_paste_ratio}, '
        repr_str += f'pre_transform={self.pre_transform}, '
        repr_str += f'prob={self.prob})'
        return repr_str
127 changes: 125 additions & 2 deletions tests/test_datasets/test_transforms/test_transforms_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,14 @@
import torch
from mmengine.testing import assert_allclose

from mmdet3d.datasets import GlobalAlignment, RandomFlip3D
from mmdet3d.datasets.transforms import GlobalRotScaleTrans
from mmdet3d.datasets import (GlobalAlignment, RandomFlip3D,
SemanticKITTIDataset)
from mmdet3d.datasets.transforms import GlobalRotScaleTrans, PolarMix
from mmdet3d.structures import LiDARPoints
from mmdet3d.testing import create_data_info_after_loading
from mmdet3d.utils import register_all_modules

register_all_modules()


class TestGlobalRotScaleTrans(unittest.TestCase):
Expand Down Expand Up @@ -99,3 +104,121 @@ def test_global_alignment(self):
# assert the rot metric
with self.assertRaises(AssertionError):
global_align_transform(data_info)


class TestPolarMix(unittest.TestCase):
    """Unit test for the ``PolarMix`` transform."""

    def setUp(self):
        """Prepare the pre-transform configs, a ``SemanticKITTIDataset``
        and a random 100-point sample that carries the dataset."""
        load_points_cfg = {
            'type': 'LoadPointsFromFile',
            'coord_type': 'LIDAR',
            'load_dim': 4,
            'use_dim': 4,
        }
        load_annotations_cfg = {
            'type': 'LoadAnnotations3D',
            'with_bbox_3d': False,
            'with_label_3d': False,
            'with_mask_3d': False,
            'with_seg_3d': True,
            'seg_3d_dtype': 'np.int32',
        }
        class_mapping_cfg = {'type': 'PointSegClassMapping'}
        self.pre_transform = [
            load_points_cfg,
            load_annotations_cfg,
            class_mapping_cfg,
        ]

        # Class names are reproduced exactly as given, spelling included.
        class_names = ('unlabeled', 'car', 'bicycle', 'motorcycle', 'truck',
                       'bus', 'person', 'bicyclist', 'motorcyclist', 'road',
                       'parking', 'sidewalk', 'other-ground', 'building',
                       'fence', 'vegetation', 'trunck', 'terrian', 'pole',
                       'traffic-sign')
        # One RGB triple per entry of ``class_names``.
        colors = [
            [174, 199, 232],
            [152, 223, 138],
            [31, 119, 180],
            [255, 187, 120],
            [188, 189, 34],
            [140, 86, 75],
            [255, 152, 150],
            [214, 39, 40],
            [197, 176, 213],
            [148, 103, 189],
            [196, 156, 148],
            [23, 190, 207],
            [247, 182, 210],
            [219, 219, 141],
            [255, 127, 14],
            [158, 218, 229],
            [44, 160, 44],
            [112, 128, 144],
            [227, 119, 194],
            [82, 84, 163],
        ]
        # Raw label id -> training label id. "merged" marks raw labels that
        # are folded into another class.
        label_map = {
            0: 0,  # "unlabeled"
            1: 0,  # "outlier" -> "unlabeled" (merged)
            10: 1,  # "car"
            11: 2,  # "bicycle"
            13: 5,  # "bus" -> "other-vehicle" (merged)
            15: 3,  # "motorcycle"
            16: 5,  # "on-rails" -> "other-vehicle" (merged)
            18: 4,  # "truck"
            20: 5,  # "other-vehicle"
            30: 6,  # "person"
            31: 7,  # "bicyclist"
            32: 8,  # "motorcyclist"
            40: 9,  # "road"
            44: 10,  # "parking"
            48: 11,  # "sidewalk"
            49: 12,  # "other-ground"
            50: 13,  # "building"
            51: 14,  # "fence"
            52: 0,  # "other-structure" -> "unlabeled" (merged)
            60: 9,  # "lane-marking" -> "road" (merged)
            70: 15,  # "vegetation"
            71: 16,  # "trunk"
            72: 17,  # "terrain"
            80: 18,  # "pole"
            81: 19,  # "traffic-sign"
            99: 0,  # "other-object" -> "unlabeled" (merged)
            252: 1,  # "moving-car" -> "car" (merged)
            253: 7,  # "moving-bicyclist" -> "bicyclist" (merged)
            254: 6,  # "moving-person" -> "person" (merged)
            255: 8,  # "moving-motorcyclist" -> "motorcyclist" (merged)
            256: 5,  # "moving-on-rails" -> "other-vehicle" (merged)
            257: 5,  # "moving-bus" -> "other-vehicle" (merged)
            258: 4,  # "moving-truck" -> "truck" (merged)
            259: 5,  # "moving-other-vehicle" -> "other-vehicle" (merged)
        }
        meta = {
            'classes': class_names,
            'palette': colors,
            'seg_label_mapping': label_map,
            'max_label': 259,
        }
        prefixes = {
            'pts': 'sequences/00/velodyne',
            'pts_semantic_mask': 'sequences/00/labels',
        }
        self.dataset = SemanticKITTIDataset(
            './tests/data/semantickitti/',
            'semantickitti_infos.pkl',
            metainfo=meta,
            data_prefix=prefixes,
            pipeline=[],
            modality={
                'use_lidar': True,
                'use_camera': False
            })

        # Random sample: 100 points with 4 dims each and labels in [0, 20).
        raw_points = np.random.random((100, 4))
        lidar_points = LiDARPoints(raw_points, points_dim=4)
        semantic_mask = np.random.randint(0, 20, (100, ))
        self.results = {
            'points': lidar_points,
            'pts_semantic_mask': semantic_mask,
            'dataset': self.dataset,
        }

    def test_transform(self):
        """Check argument validation and point/label count consistency."""
        # ``instance_classes`` that is not a list must be rejected
        with self.assertRaises(AssertionError):
            PolarMix(instance_classes=1)

        # a list holding non-int entries must be rejected as well
        with self.assertRaises(AssertionError):
            PolarMix(instance_classes=[1.0, 2.0])

        polar_mix = PolarMix(
            instance_classes=[1, 2],
            swap_ratio=1.0,
            pre_transform=self.pre_transform)
        sample = copy.deepcopy(self.results)
        mixed = polar_mix.transform(sample)

        # every point must still have exactly one semantic label
        num_points = mixed['points'].shape[0]
        num_labels = mixed['pts_semantic_mask'].shape[0]
        self.assertTrue(num_points == num_labels)

0 comments on commit 60d848b

Please sign in to comment.