From 60d848b3d379dbb8673e51d475d1f00f45b34682 Mon Sep 17 00:00:00 2001 From: Xiang Xu Date: Thu, 23 Feb 2023 13:05:13 +0800 Subject: [PATCH] [Feature] Support `PolarMix` augmentation (#2265) * support polarmix * Update __init__.py * add UT * use `BasePoints` instead of numpy * Update transforms_3d.py * Update transforms_3d.py * Update test_transforms_3d.py * update docs * update polarmix without MultiImageMixDataset * add comments * fix UT * update docstring * fix yaw calculation * fix UT * refactor * update * update docs * fix typo * Update transforms_3d.py * update ut * fix typehint * add prob argument --- mmdet3d/datasets/seg3d_dataset.py | 18 ++ mmdet3d/datasets/transforms/__init__.py | 4 +- mmdet3d/datasets/transforms/transforms_3d.py | 173 +++++++++++++++++- .../test_transforms/test_transforms_3d.py | 127 ++++++++++++- 4 files changed, 316 insertions(+), 6 deletions(-) diff --git a/mmdet3d/datasets/seg3d_dataset.py b/mmdet3d/datasets/seg3d_dataset.py index 513883a1ec..42025dee49 100644 --- a/mmdet3d/datasets/seg3d_dataset.py +++ b/mmdet3d/datasets/seg3d_dataset.py @@ -283,6 +283,24 @@ def parse_data_info(self, info: dict) -> dict: return info + def prepare_data(self, idx: int) -> dict: + """Get data processed by ``self.pipeline``. + + Args: + idx (int): The index of ``data_info``. + + Returns: + dict: Results passed through ``self.pipeline``. + """ + if not self.test_mode: + data_info = self.get_data_info(idx) + # Pass the dataset to the pipeline during training to support mixed + # data augmentation, such as polarmix. + data_info['dataset'] = self + return self.pipeline(data_info) + else: + return super().prepare_data(idx) + def get_scene_idxs(self, scene_idxs: Union[None, str, np.ndarray]) -> np.ndarray: """Compute scene_idxs for data sampling. 
diff --git a/mmdet3d/datasets/transforms/__init__.py b/mmdet3d/datasets/transforms/__init__.py index 72d5a8f42b..c8969f8b60 100644 --- a/mmdet3d/datasets/transforms/__init__.py +++ b/mmdet3d/datasets/transforms/__init__.py @@ -14,7 +14,7 @@ MultiViewWrapper, ObjectNameFilter, ObjectNoise, ObjectRangeFilter, ObjectSample, PhotoMetricDistortion3D, PointSample, PointShuffle, - PointsRangeFilter, RandomDropPointsColor, + PointsRangeFilter, PolarMix, RandomDropPointsColor, RandomFlip3D, RandomJitterPoints, RandomResize3D, RandomShiftScale, Resize3D, VoxelBasedPointSampler) @@ -30,5 +30,5 @@ 'RandomDropPointsColor', 'RandomJitterPoints', 'AffineResize', 'RandomShiftScale', 'LoadPointsFromDict', 'Resize3D', 'RandomResize3D', 'MultiViewWrapper', 'PhotoMetricDistortion3D', 'MonoDet3DInferencerLoader', - 'LidarDet3DInferencerLoader' + 'LidarDet3DInferencerLoader', 'PolarMix' ] diff --git a/mmdet3d/datasets/transforms/transforms_3d.py b/mmdet3d/datasets/transforms/transforms_3d.py index 495f0c1e2d..dbdbf2a45c 100644 --- a/mmdet3d/datasets/transforms/transforms_3d.py +++ b/mmdet3d/datasets/transforms/transforms_3d.py @@ -1,15 +1,16 @@ # Copyright (c) OpenMMLab. All rights reserved. import random import warnings -from typing import List, Optional, Tuple, Union +from typing import List, Optional, Sequence, Tuple, Union import cv2 import mmcv import numpy as np +import torch from mmcv.transforms import BaseTransform, Compose, RandomResize, Resize from mmdet.datasets.transforms import (PhotoMetricDistortion, RandomCrop, RandomFlip) -from mmengine import is_tuple_of +from mmengine import is_list_of, is_tuple_of from mmdet3d.models.task_modules import VoxelGenerator from mmdet3d.registry import TRANSFORMS @@ -2352,3 +2353,171 @@ def transform(self, input_dict: dict) -> dict: if len(input_dict[key]) == 0: input_dict.pop(key) return input_dict + + +@TRANSFORMS.register_module() +class PolarMix(BaseTransform): + """PolarMix data augmentation. 
+ + The polarmix transform steps are as follows: + + 1. Another random point cloud is picked from the dataset. + 2. Exchange sectors of two point clouds that are cut with certain + azimuth angles. + 3. Cut point instances from the picked point cloud, rotate them by multiple + azimuth angles, and paste the cut and rotated instances. + + Required Keys: + + - points (:obj:`BasePoints`) + - pts_semantic_mask (np.int64) + - dataset (:obj:`BaseDataset`) + + Modified Keys: + + - points (:obj:`BasePoints`) + - pts_semantic_mask (np.int64) + + Args: + instance_classes (List[int]): Semantic labels of the classes whose points + are cut and rotate-pasted as instances. + swap_ratio (float): Probability of swapping sectors between the two point clouds. Defaults to 0.5. + rotate_paste_ratio (float): Probability of applying rotate-pasting. Defaults to 1.0. + pre_transform (Sequence[dict], optional): Sequence of transform objects + or config dicts to be composed. Defaults to None. + prob (float): The transformation probability. Defaults to 1.0. + """ + + def __init__(self, + instance_classes: List[int], + swap_ratio: float = 0.5, + rotate_paste_ratio: float = 1.0, + pre_transform: Optional[Sequence[dict]] = None, + prob: float = 1.0) -> None: + assert is_list_of(instance_classes, int), \ + 'instance_classes should be a list of int' + self.instance_classes = instance_classes + self.swap_ratio = swap_ratio + self.rotate_paste_ratio = rotate_paste_ratio + + self.prob = prob + if pre_transform is None: + self.pre_transform = None + else: + self.pre_transform = Compose(pre_transform) + + def polar_mix_transform(self, input_dict: dict, mix_results: dict) -> dict: + """PolarMix transform function. + + Args: + input_dict (dict): Result dict from loading pipeline. + mix_results (dict): Mixed dict picked from dataset. + + Returns: + dict: output dict after transformation. + """ + mix_points = mix_results['points'] + mix_pts_semantic_mask = mix_results['pts_semantic_mask'] + + points = input_dict['points'] + pts_semantic_mask = input_dict['pts_semantic_mask'] + + # 1. 
swap point cloud + if np.random.random() < self.swap_ratio: + start_angle = (np.random.random() - 1) * np.pi # -pi~0 + end_angle = start_angle + np.pi + # calculate horizontal angle for each point + yaw = -torch.atan2(points.coord[:, 1], points.coord[:, 0]) + mix_yaw = -torch.atan2(mix_points.coord[:, 1], mix_points.coord[:, + 0]) + + # keep original points outside the sector and mixed points inside it + idx = (yaw <= start_angle) | (yaw >= end_angle) + mix_idx = (mix_yaw > start_angle) & (mix_yaw < end_angle) + + # swap + points = points.cat([points[idx], mix_points[mix_idx]]) + pts_semantic_mask = np.concatenate( + (pts_semantic_mask[idx.numpy()], + mix_pts_semantic_mask[mix_idx.numpy()]), + axis=0) + + # 2. rotate-pasting + if np.random.random() < self.rotate_paste_ratio: + # extract instance points + instance_points, instance_pts_semantic_mask = [], [] + for instance_class in self.instance_classes: + mix_idx = mix_pts_semantic_mask == instance_class + instance_points.append(mix_points[mix_idx]) + instance_pts_semantic_mask.append( + mix_pts_semantic_mask[mix_idx]) + instance_points = mix_points.cat(instance_points) + instance_pts_semantic_mask = np.concatenate( + instance_pts_semantic_mask, axis=0) + + # rotate-copy: keep the instances and add two copies rotated by random angles in [0, 2*pi/3) and [2*pi/3, 4*pi/3) + copy_points = [instance_points] + copy_pts_semantic_mask = [instance_pts_semantic_mask] + angle_list = [ + np.random.random() * np.pi * 2 / 3, + (np.random.random() + 1) * np.pi * 2 / 3 + ] + for angle in angle_list: + new_points = instance_points.clone() + new_points.rotate(angle) + copy_points.append(new_points) + copy_pts_semantic_mask.append(instance_pts_semantic_mask) + copy_points = instance_points.cat(copy_points) + copy_pts_semantic_mask = np.concatenate( + copy_pts_semantic_mask, axis=0) + + points = points.cat([points, copy_points]) + pts_semantic_mask = np.concatenate( + (pts_semantic_mask, copy_pts_semantic_mask), axis=0) + + input_dict['points'] = points + input_dict['pts_semantic_mask'] = pts_semantic_mask + return input_dict + + def transform(self, input_dict: dict) -> dict: 
+ """PolarMix transform function. + + Args: + input_dict (dict): Result dict from loading pipeline. + + Returns: + dict: output dict after transformation. + """ + if np.random.rand() > self.prob: + return input_dict + + assert 'dataset' in input_dict, \ + '`dataset` is needed to pass through PolarMix, while not found.' + dataset = input_dict['dataset'] + + # get index of other point cloud + index = np.random.randint(0, len(dataset)) + + mix_results = dataset.get_data_info(index) + + if self.pre_transform is not None: + # pre_transform may also require dataset + mix_results.update({'dataset': dataset}) + # before polarmix need to go through + # the necessary pre_transform + mix_results = self.pre_transform(mix_results) + mix_results.pop('dataset') + + input_dict = self.polar_mix_transform(input_dict, mix_results) + + return input_dict + + def __repr__(self) -> str: + """str: Return a string that describes the module.""" + repr_str = self.__class__.__name__ + repr_str += f'(instance_classes={self.instance_classes}, ' + repr_str += f'swap_ratio={self.swap_ratio}, ' + repr_str += f'rotate_paste_ratio={self.rotate_paste_ratio}, ' + repr_str += f'pre_transform={self.pre_transform}, ' + repr_str += f'prob={self.prob})' + return repr_str diff --git a/tests/test_datasets/test_transforms/test_transforms_3d.py b/tests/test_datasets/test_transforms/test_transforms_3d.py index 81e222000c..b66c3bb3c1 100644 --- a/tests/test_datasets/test_transforms/test_transforms_3d.py +++ b/tests/test_datasets/test_transforms/test_transforms_3d.py @@ -6,9 +6,14 @@ import torch from mmengine.testing import assert_allclose -from mmdet3d.datasets import GlobalAlignment, RandomFlip3D -from mmdet3d.datasets.transforms import GlobalRotScaleTrans +from mmdet3d.datasets import (GlobalAlignment, RandomFlip3D, + SemanticKITTIDataset) +from mmdet3d.datasets.transforms import GlobalRotScaleTrans, PolarMix +from mmdet3d.structures import LiDARPoints from mmdet3d.testing import 
create_data_info_after_loading +from mmdet3d.utils import register_all_modules + +register_all_modules() class TestGlobalRotScaleTrans(unittest.TestCase): @@ -99,3 +104,121 @@ def test_global_alignment(self): # assert the rot metric with self.assertRaises(AssertionError): global_align_transform(data_info) + + +class TestPolarMix(unittest.TestCase): + + def setUp(self): + self.pre_transform = [ + dict( + type='LoadPointsFromFile', + coord_type='LIDAR', + load_dim=4, + use_dim=4), + dict( + type='LoadAnnotations3D', + with_bbox_3d=False, + with_label_3d=False, + with_mask_3d=False, + with_seg_3d=True, + seg_3d_dtype='np.int32'), + dict(type='PointSegClassMapping'), + ] + classes = ('unlabeled', 'car', 'bicycle', 'motorcycle', 'truck', 'bus', + 'person', 'bicyclist', 'motorcyclist', 'road', 'parking', + 'sidewalk', 'other-ground', 'building', 'fence', + 'vegetation', 'trunck', 'terrian', 'pole', 'traffic-sign') + palette = [ + [174, 199, 232], + [152, 223, 138], + [31, 119, 180], + [255, 187, 120], + [188, 189, 34], + [140, 86, 75], + [255, 152, 150], + [214, 39, 40], + [197, 176, 213], + [148, 103, 189], + [196, 156, 148], + [23, 190, 207], + [247, 182, 210], + [219, 219, 141], + [255, 127, 14], + [158, 218, 229], + [44, 160, 44], + [112, 128, 144], + [227, 119, 194], + [82, 84, 163], + ] + seg_label_mapping = { + 0: 0, # "unlabeled" + 1: 0, # "outlier" mapped to "unlabeled" --------------mapped + 10: 1, # "car" + 11: 2, # "bicycle" + 13: 5, # "bus" mapped to "other-vehicle" --------------mapped + 15: 3, # "motorcycle" + 16: 5, # "on-rails" mapped to "other-vehicle" ---------mapped + 18: 4, # "truck" + 20: 5, # "other-vehicle" + 30: 6, # "person" + 31: 7, # "bicyclist" + 32: 8, # "motorcyclist" + 40: 9, # "road" + 44: 10, # "parking" + 48: 11, # "sidewalk" + 49: 12, # "other-ground" + 50: 13, # "building" + 51: 14, # "fence" + 52: 0, # "other-structure" mapped to "unlabeled" ------mapped + 60: 9, # "lane-marking" to "road" ---------------------mapped + 70: 15, # 
"vegetation" + 71: 16, # "trunk" + 72: 17, # "terrain" + 80: 18, # "pole" + 81: 19, # "traffic-sign" + 99: 0, # "other-object" to "unlabeled" ----------------mapped + 252: 1, # "moving-car" to "car" ------------------------mapped + 253: 7, # "moving-bicyclist" to "bicyclist" ------------mapped + 254: 6, # "moving-person" to "person" ------------------mapped + 255: 8, # "moving-motorcyclist" to "motorcyclist" ------mapped + 256: 5, # "moving-on-rails" mapped to "other-vehicle" --mapped + 257: 5, # "moving-bus" mapped to "other-vehicle" -------mapped + 258: 4, # "moving-truck" to "truck" --------------------mapped + 259: 5 # "moving-other-vehicle" to "other-vehicle" ----mapped + } + max_label = 259 + self.dataset = SemanticKITTIDataset( + './tests/data/semantickitti/', + 'semantickitti_infos.pkl', + metainfo=dict( + classes=classes, + palette=palette, + seg_label_mapping=seg_label_mapping, + max_label=max_label), + data_prefix=dict( + pts='sequences/00/velodyne', + pts_semantic_mask='sequences/00/labels'), + pipeline=[], + modality=dict(use_lidar=True, use_camera=False)) + points = np.random.random((100, 4)) + self.results = { + 'points': LiDARPoints(points, points_dim=4), + 'pts_semantic_mask': np.random.randint(0, 20, (100, )), + 'dataset': self.dataset + } + + def test_transform(self): + # test assertion for invalid instance_classes + with self.assertRaises(AssertionError): + transform = PolarMix(instance_classes=1) + + with self.assertRaises(AssertionError): + transform = PolarMix(instance_classes=[1.0, 2.0]) + + transform = PolarMix( + instance_classes=[1, 2], + swap_ratio=1.0, + pre_transform=self.pre_transform) + results = transform.transform(copy.deepcopy(self.results)) + self.assertTrue(results['points'].shape[0] == + results['pts_semantic_mask'].shape[0])