From 952a59235771c9226b8f5016d3d8e615b8eb298b Mon Sep 17 00:00:00 2001 From: SekiroRong <64967662+SekiroRong@users.noreply.github.com> Date: Thu, 5 Jan 2023 09:42:01 +0800 Subject: [PATCH] [Feat]: Support PETR in 1.1 in `projects` (#2175) * rebase * petr3d- * petr3d-to-petr * delete_NormalizeMultiviewImage * rename_PETR * rename_PETR * fix_bug * fix_bug * fix_bug * fix_bug * fix_bug * fix_bug * fix_bug * revise * remove_builder * remove_builder * remove_use_external * remove_use_external * remove_PadMultiViewImage * remove_PadMultiViewImage * remove-AddCamInfo * remove-LidarBox3dVersionTransfrom * remove-LidarBox3dVersionTransfrom-and-AddCamInfo * fix__init__ * remove-redundent-config * code-polish * remove-builder * remove-builder * remove-redundent-files * replace-forward-train-and-test * remove-redundent__init__ * remove_petr * remove-hierarchtecture --- projects/PETR/README.md | 63 ++ .../petr/petr_vovnet_gridmask_p4_800x320.py | 357 ++++++++ projects/PETR/petr/__init__.py | 24 + projects/PETR/petr/cp_fpn.py | 211 +++++ projects/PETR/petr/grid_mask.py | 146 ++++ projects/PETR/petr/hungarian_assigner_3d.py | 142 +++ projects/PETR/petr/match_cost.py | 338 +++++++ projects/PETR/petr/nms_free_coder.py | 246 ++++++ projects/PETR/petr/petr.py | 299 +++++++ projects/PETR/petr/petr_head.py | 825 ++++++++++++++++++ projects/PETR/petr/petr_transformer.py | 540 ++++++++++++ projects/PETR/petr/positional_encoding.py | 171 ++++ projects/PETR/petr/transforms_3d.py | 207 +++++ projects/PETR/petr/utils.py | 69 ++ projects/PETR/petr/vovnetcp.py | 475 ++++++++++ 15 files changed, 4113 insertions(+) create mode 100644 projects/PETR/README.md create mode 100644 projects/PETR/config/petr/petr_vovnet_gridmask_p4_800x320.py create mode 100644 projects/PETR/petr/__init__.py create mode 100644 projects/PETR/petr/cp_fpn.py create mode 100644 projects/PETR/petr/grid_mask.py create mode 100644 projects/PETR/petr/hungarian_assigner_3d.py create mode 100644 projects/PETR/petr/match_cost.py create mode 100644 projects/PETR/petr/nms_free_coder.py create mode 100644 projects/PETR/petr/petr.py create mode 100644 projects/PETR/petr/petr_head.py create mode 100644 projects/PETR/petr/petr_transformer.py create mode 100644 projects/PETR/petr/positional_encoding.py create mode 100644 projects/PETR/petr/transforms_3d.py create mode 100644 projects/PETR/petr/utils.py create mode 100644 projects/PETR/petr/vovnetcp.py diff --git a/projects/PETR/README.md b/projects/PETR/README.md new file mode 100644 index 0000000000..38fc58b516 --- /dev/null +++ b/projects/PETR/README.md @@ -0,0 +1,63 @@ +# PETR + +This is a README for `PETR`. + +## Description + +Author: @SekiroRong. +This is an implementation of *PETR*. + +## Usage + + + +### Training commands + +In MMDet3D's root directory, run the following command to train the model: + +```bash +python tools/train.py projects/PETR/config/petr/petr_vovnet_gridmask_p4_800x320.py +``` + +### Testing commands + +In MMDet3D's root directory, run the following command to test the model: + +```bash +python tools/test.py projects/PETR/config/petr/petr_vovnet_gridmask_p4_800x320.py ${CHECKPOINT_PATH} +``` + +## Results + + + +This result is trained with petr_vovnet_gridmask_p4_800x320.py, using these [weights](https://drive.google.com/file/d/1ABI5BoQCkCkP4B0pO5KBJ3Ni0tei0gZi/view?usp=sharing) as the pretrained weights.
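As a quick sanity check before training, you can load the downloaded checkpoint and confirm it contains the expected backbone tensors. This is a minimal sketch; the path below is a placeholder for wherever you saved the file:

```python
import torch

# Placeholder path: point this at the downloaded, remapped VoVNet checkpoint.
ckpt = torch.load('checkpoints/fcos3d_vovnet_imgbackbone-remapped.pth',
                  map_location='cpu')
state_dict = ckpt.get('state_dict', ckpt)  # handle both wrapped and raw formats
print(f'{len(state_dict)} tensors loaded')
print(sorted(state_dict)[:5])  # e.g. VoVNet stem/stage keys
```

If you keep the checkpoint elsewhere, update `load_from` in the config accordingly.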
+ +| Backbone | Lr schd | Mem (GB) | Inf time (fps) | mAP | NDS | Download | +| :----------------------------------------------------------------------------------------------: | :-----: | :------: | :------------: | :--: | :--: | :----------------------: | +| [petr_vovnet_gridmask_p4_800x320](projects/PETR/configs/petr/petr_vovnet_gridmask_p4_800x320.py) | 1x | 7.62 | 18.7 | 38.3 | 43.5 | [model](<>) \| [log](<>) | + +``` +mAP: 0.3830 +mATE: 0.7547 +mASE: 0.2683 +mAOE: 0.4948 +mAVE: 0.8331 +mAAE: 0.2056 +NDS: 0.4358 +Eval time: 118.7s + +Per-class results: +Object Class AP ATE ASE AOE AVE AAE +car 0.567 0.538 0.151 0.086 0.873 0.212 +truck 0.341 0.785 0.213 0.113 0.821 0.234 +bus 0.426 0.766 0.201 0.128 1.813 0.343 +trailer 0.216 1.116 0.227 0.649 0.640 0.122 +construction_vehicle 0.093 1.118 0.483 1.292 0.217 0.330 +pedestrian 0.453 0.685 0.293 0.644 0.535 0.238 +motorcycle 0.374 0.700 0.253 0.624 1.291 0.154 +bicycle 0.345 0.622 0.262 0.775 0.475 0.011 +traffic_cone 0.539 0.557 0.319 nan nan nan +barrier 0.476 0.661 0.279 0.142 nan nan +``` diff --git a/projects/PETR/config/petr/petr_vovnet_gridmask_p4_800x320.py b/projects/PETR/config/petr/petr_vovnet_gridmask_p4_800x320.py new file mode 100644 index 0000000000..06898307e1 --- /dev/null +++ b/projects/PETR/config/petr/petr_vovnet_gridmask_p4_800x320.py @@ -0,0 +1,357 @@ +_base_ = [ + 'mmdet3d::_base_/datasets/nus-3d.py', 'mmdet3d::_base_/default_runtime.py', + 'mmdet3d::_base_/schedules/cyclic-20e.py' +] +backbone_norm_cfg = dict(type='LN', requires_grad=True) +custom_imports = dict(imports=['projects.PETR.petr']) + +randomness = dict(seed=1, deterministic=False, diff_rank_seed=False) +# If point cloud range is changed, the models should also change their point +# cloud range accordingly +point_cloud_range = [-51.2, -51.2, -5.0, 51.2, 51.2, 3.0] +voxel_size = [0.2, 0.2, 8] +img_norm_cfg = dict( + mean=[103.530, 116.280, 123.675], + std=[57.375, 57.120, 58.395], + to_rgb=False) +# For nuScenes we usually do 10-class detection +class_names = [ + 'car', 'truck', 'construction_vehicle', 'bus', 'trailer', 'barrier', + 'motorcycle', 'bicycle', 'pedestrian', 'traffic_cone' +] +metainfo = dict(classes=class_names) + +input_modality = dict(use_camera=True) +model = dict( + type='PETR', + data_preprocessor=dict( + type='Det3DDataPreprocessor', + mean=[103.530, 116.280, 123.675], + std=[57.375, 57.120, 58.395], + bgr_to_rgb=False, + pad_size_divisor=32), + use_grid_mask=True, + img_backbone=dict( + type='VoVNetCP', + spec_name='V-99-eSE', + norm_eval=True, + frozen_stages=-1, + input_ch=3, + out_features=( + 'stage4', + 'stage5', + )), + img_neck=dict( + type='CPFPN', in_channels=[768, 1024], out_channels=256, num_outs=2), + pts_bbox_head=dict( + type='PETRHead', + num_classes=10, + in_channels=256, + num_query=900, + LID=True, + with_position=True, + with_multiview=True, + position_range=[-61.2, -61.2, -10.0, 61.2, 61.2, 10.0], + normedlinear=False, + transformer=dict( + type='PETRTransformer', + decoder=dict( + type='PETRTransformerDecoder', + return_intermediate=True, + num_layers=6, + transformerlayers=dict( + type='PETRTransformerDecoderLayer', + attn_cfgs=[ + dict( + type='MultiheadAttention', + embed_dims=256, + num_heads=8, + attn_drop=0.1, + dropout_layer=dict(type='Dropout', drop_prob=0.1)), + dict( + type='PETRMultiheadAttention', + embed_dims=256, + num_heads=8, + attn_drop=0.1, + dropout_layer=dict(type='Dropout', drop_prob=0.1)), + ], + feedforward_channels=2048, + ffn_dropout=0.1, + operation_order=('self_attn', 'norm', 
'cross_attn', 'norm', + 'ffn', 'norm')), + )), + bbox_coder=dict( + type='NMSFreeCoder', + post_center_range=[-61.2, -61.2, -10.0, 61.2, 61.2, 10.0], + pc_range=point_cloud_range, + max_num=300, + voxel_size=voxel_size, + num_classes=10), + positional_encoding=dict( + type='SinePositionalEncoding3D', num_feats=128, normalize=True), + loss_cls=dict( + type='mmdet.FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=2.0), + loss_bbox=dict(type='mmdet.L1Loss', loss_weight=0.25), + loss_iou=dict(type='mmdet.GIoULoss', loss_weight=0.0)), + # model training and testing settings + train_cfg=dict( + pts=dict( + grid_size=[512, 512, 1], + voxel_size=voxel_size, + point_cloud_range=point_cloud_range, + out_size_factor=4, + assigner=dict( + type='HungarianAssigner3D', + cls_cost=dict(type='FocalLossCost', weight=2.0), + reg_cost=dict(type='BBox3DL1Cost', weight=0.25), + iou_cost=dict( + type='IoUCost', weight=0.0 + ), # Fake cost. Just to be compatible with DETR head. + pc_range=point_cloud_range)))) + +dataset_type = 'NuScenesDataset' +data_root = 'data/nuscenes/' +file_client_args = dict(backend='disk') + +db_sampler = dict( + data_root=data_root, + info_path=data_root + 'nuscenes_dbinfos_train.pkl', + rate=1.0, + prepare=dict( + filter_by_difficulty=[-1], + filter_by_min_points=dict( + car=5, + truck=5, + bus=5, + trailer=5, + construction_vehicle=5, + traffic_cone=5, + barrier=5, + motorcycle=5, + bicycle=5, + pedestrian=5)), + classes=class_names, + sample_groups=dict( + car=2, + truck=3, + construction_vehicle=7, + bus=4, + trailer=6, + barrier=2, + motorcycle=6, + bicycle=6, + pedestrian=2, + traffic_cone=2), + points_loader=dict( + type='LoadPointsFromFile', + coord_type='LIDAR', + load_dim=5, + use_dim=[0, 1, 2, 3, 4])) +ida_aug_conf = { + 'resize_lim': (0.47, 0.625), + 'final_dim': (320, 800), + 'bot_pct_lim': (0.0, 0.0), + 'rot_lim': (0.0, 0.0), + 'H': 900, + 'W': 1600, + 'rand_flip': True, +} + +train_pipeline = [ + dict(type='LoadMultiViewImageFromFiles', to_float32=True), + dict( + type='LoadAnnotations3D', + with_bbox_3d=True, + with_label_3d=True, + with_attr_label=False), + dict(type='ObjectRangeFilter', point_cloud_range=point_cloud_range), + dict(type='ObjectNameFilter', classes=class_names), + dict( + type='ResizeCropFlipImage', data_aug_conf=ida_aug_conf, training=True), + dict( + type='GlobalRotScaleTransImage', + rot_range=[-0.3925, 0.3925], + translation_std=[0, 0, 0], + scale_ratio_range=[0.95, 1.05], + reverse_angle=False, + training=True), + dict( + type='Pack3DDetInputs', + keys=[ + 'img', 'gt_bboxes', 'gt_bboxes_labels', 'attr_labels', + 'gt_bboxes_3d', 'gt_labels_3d', 'centers_2d', 'depths' + ]) +] +test_pipeline = [ + dict(type='LoadMultiViewImageFromFiles', to_float32=True), + dict( + type='ResizeCropFlipImage', data_aug_conf=ida_aug_conf, + training=False), + dict(type='Pack3DDetInputs', keys=['img']) +] + +train_dataloader = dict( + batch_size=1, + num_workers=4, + dataset=dict( + type=dataset_type, + data_prefix=dict( + pts='samples/LIDAR_TOP', + CAM_FRONT='samples/CAM_FRONT', + CAM_FRONT_LEFT='samples/CAM_FRONT_LEFT', + CAM_FRONT_RIGHT='samples/CAM_FRONT_RIGHT', + CAM_BACK='samples/CAM_BACK', + CAM_BACK_RIGHT='samples/CAM_BACK_RIGHT', + CAM_BACK_LEFT='samples/CAM_BACK_LEFT'), + pipeline=train_pipeline, + box_type_3d='LiDAR', + metainfo=metainfo, + test_mode=False, + modality=input_modality, + use_valid_flag=True)) +test_dataloader = dict( + dataset=dict( + type=dataset_type, + data_prefix=dict( + pts='samples/LIDAR_TOP', + 
CAM_FRONT='samples/CAM_FRONT', + CAM_FRONT_LEFT='samples/CAM_FRONT_LEFT', + CAM_FRONT_RIGHT='samples/CAM_FRONT_RIGHT', + CAM_BACK='samples/CAM_BACK', + CAM_BACK_RIGHT='samples/CAM_BACK_RIGHT', + CAM_BACK_LEFT='samples/CAM_BACK_LEFT'), + pipeline=test_pipeline, + box_type_3d='LiDAR', + metainfo=metainfo, + test_mode=True, + modality=input_modality, + use_valid_flag=True)) +val_dataloader = dict( + dataset=dict( + type=dataset_type, + data_prefix=dict( + pts='samples/LIDAR_TOP', + CAM_FRONT='samples/CAM_FRONT', + CAM_FRONT_LEFT='samples/CAM_FRONT_LEFT', + CAM_FRONT_RIGHT='samples/CAM_FRONT_RIGHT', + CAM_BACK='samples/CAM_BACK', + CAM_BACK_RIGHT='samples/CAM_BACK_RIGHT', + CAM_BACK_LEFT='samples/CAM_BACK_LEFT'), + pipeline=test_pipeline, + box_type_3d='LiDAR', + metainfo=metainfo, + test_mode=True, + modality=input_modality, + use_valid_flag=True)) + +# Different from original PETR: +# We don't use special lr for image_backbone +# This seems won't affect model performance +optim_wrapper = dict( + # TODO Add Amp + # type='AmpOptimWrapper', + # loss_scale='dynamic', + optimizer=dict(type='AdamW', lr=2e-4, weight_decay=0.01), + paramwise_cfg=dict(custom_keys={ + 'img_backbone': dict(lr_mult=0.1), + }), + clip_grad=dict(max_norm=35, norm_type=2)) + +num_epochs = 24 + +param_scheduler = [ + dict( + type='LinearLR', + start_factor=1.0 / 3, + begin=0, + end=500, + by_epoch=False), + dict( + type='CosineAnnealingLR', + # TODO Figure out what T_max + T_max=num_epochs, + by_epoch=True, + ) +] + +train_cfg = dict(max_epochs=num_epochs, val_interval=num_epochs) + +find_unused_parameters = False + +# pretrain_path can be found here: +# https://drive.google.com/file/d/1ABI5BoQCkCkP4B0pO5KBJ3Ni0tei0gZi/view +load_from = '/mnt/d/fcos3d_vovnet_imgbackbone-remapped.pth' +resume = False + +# --------------Original--------------- +# mAP: 0.3778 +# mATE: 0.7463 +# mASE: 0.2718 +# mAOE: 0.4883 +# mAVE: 0.9062 +# mAAE: 0.2123 +# NDS: 0.4264 +# Eval time: 242.1s + +# Per-class results: +# Object Class AP ATE ASE AOE AVE AAE +# car 0.556 0.555 0.153 0.091 0.917 0.216 +# truck 0.330 0.805 0.218 0.119 0.859 0.250 +# bus 0.412 0.789 0.205 0.162 2.067 0.337 +# trailer 0.221 0.976 0.233 0.663 0.797 0.146 +# construction_vehicle 0.094 1.096 0.493 1.145 0.190 0.349 +# pedestrian 0.453 0.688 0.289 0.636 0.549 0.235 +# motorcycle 0.368 0.690 0.256 0.622 1.417 0.149 +# bicycle 0.341 0.609 0.270 0.812 0.455 0.017 +# traffic_cone 0.531 0.582 0.320 nan nan nan +# barrier 0.472 0.673 0.281 0.145 nan nan + +# --------------Refactored in mmdet3d v1.0--------------- +# mAP: 0.3827 +# mATE: 0.7375 +# mASE: 0.2703 +# mAOE: 0.4799 +# mAVE: 0.8699 +# mAAE: 0.2038 +# NDS: 0.4352 +# Eval time: 124.8s + +# Per-class results: +# Object Class AP ATE ASE AOE AVE AAE +# car 0.574 0.519 0.150 0.087 0.865 0.206 +# truck 0.349 0.773 0.213 0.117 0.855 0.220 +# bus 0.423 0.781 0.204 0.122 1.902 0.319 +# trailer 0.219 1.034 0.231 0.608 0.830 0.149 +# construction_vehicle 0.084 1.062 0.486 1.245 0.172 0.360 +# pedestrian 0.452 0.681 0.293 0.646 0.529 0.231 +# motorcycle 0.378 0.670 0.250 0.567 1.334 0.130 +# bicycle 0.347 0.639 0.264 0.788 0.472 0.016 +# traffic_cone 0.538 0.553 0.325 nan nan nan +# barrier 0.464 0.662 0.287 0.137 nan nan + +# --------------Refactored in mmdet3d v1.1--------------- +# mAP: 0.3830 +# mATE: 0.7547 +# mASE: 0.2683 +# mAOE: 0.4948 +# mAVE: 0.8331 +# mAAE: 0.2056 +# NDS: 0.4358 +# Eval time: 118.7s + +# Per-class results: +# Object Class AP ATE ASE AOE AVE AAE +# car 0.567 0.538 0.151 0.086 0.873 0.212 +# truck 0.341 0.785 
0.213 0.113 0.821 0.234 +# bus 0.426 0.766 0.201 0.128 1.813 0.343 +# trailer 0.216 1.116 0.227 0.649 0.640 0.122 +# construction_vehicle 0.093 1.118 0.483 1.292 0.217 0.330 +# pedestrian 0.453 0.685 0.293 0.644 0.535 0.238 +# motorcycle 0.374 0.700 0.253 0.624 1.291 0.154 +# bicycle 0.345 0.622 0.262 0.775 0.475 0.011 +# traffic_cone 0.539 0.557 0.319 nan nan nan +# barrier 0.476 0.661 0.279 0.142 nan nan diff --git a/projects/PETR/petr/__init__.py b/projects/PETR/petr/__init__.py new file mode 100644 index 0000000000..2ed2ecc908 --- /dev/null +++ b/projects/PETR/petr/__init__.py @@ -0,0 +1,24 @@ +from .cp_fpn import CPFPN +from .hungarian_assigner_3d import HungarianAssigner3D +from .match_cost import BBox3DL1Cost +from .nms_free_coder import NMSFreeCoder +from .petr import PETR +from .petr_head import PETRHead +from .petr_transformer import (PETRDNTransformer, PETRMultiheadAttention, + PETRTransformer, PETRTransformerDecoder, + PETRTransformerDecoderLayer, + PETRTransformerEncoder) +from .positional_encoding import (LearnedPositionalEncoding3D, + SinePositionalEncoding3D) +from .transforms_3d import GlobalRotScaleTransImage, ResizeCropFlipImage +from .utils import denormalize_bbox, normalize_bbox +from .vovnetcp import VoVNetCP + +__all__ = [ + 'GlobalRotScaleTransImage', 'ResizeCropFlipImage', 'VoVNetCP', 'PETRHead', + 'CPFPN', 'HungarianAssigner3D', 'NMSFreeCoder', 'BBox3DL1Cost', + 'LearnedPositionalEncoding3D', 'PETRDNTransformer', + 'PETRMultiheadAttention', 'PETRTransformer', 'PETRTransformerDecoder', + 'PETRTransformerDecoderLayer', 'PETRTransformerEncoder', 'PETR', + 'SinePositionalEncoding3D', 'denormalize_bbox', 'normalize_bbox' +] diff --git a/projects/PETR/petr/cp_fpn.py b/projects/PETR/petr/cp_fpn.py new file mode 100644 index 0000000000..02c902485b --- /dev/null +++ b/projects/PETR/petr/cp_fpn.py @@ -0,0 +1,211 @@ +# ------------------------------------------------------------------------ +# Copyright (c) 2022 megvii-model. All Rights Reserved. +# ------------------------------------------------------------------------ +# Modified from mmdetection (https://github.com/open-mmlab/mmdetection) +# Copyright (c) OpenMMLab. All rights reserved. +# ------------------------------------------------------------------------ +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import ConvModule +from mmengine.model import BaseModule + +from mmdet3d.registry import MODELS + + +# This FPN removes unused parameters, so it can be used with checkpointing +# (with_cp=True) +@MODELS.register_module() +class CPFPN(BaseModule): + r"""Feature Pyramid Network. + + This is an implementation of paper `Feature Pyramid Networks for Object + Detection <https://arxiv.org/abs/1612.03144>`_. + + Args: + in_channels (List[int]): Number of input channels per scale. + out_channels (int): Number of output channels (used at each scale). + num_outs (int): Number of output scales. + start_level (int): Index of the start input backbone level used to + build the feature pyramid. Default: 0. + end_level (int): Index of the end input backbone level (exclusive) to + build the feature pyramid. Default: -1, which means the last level. + add_extra_convs (bool | str): If bool, it decides whether to add conv + layers on top of the original feature maps. Default to False. + If True, it is equivalent to `add_extra_convs='on_input'`. + If str, it specifies the source feature map of the extra convs. + Only the following options are allowed + + - 'on_input': Last feat map of neck inputs (i.e. backbone feature).
+ - 'on_lateral': Last feature map after lateral convs. + - 'on_output': The last output feature map after fpn convs. + relu_before_extra_convs (bool): Whether to apply relu before the extra + conv. Default: False. + no_norm_on_lateral (bool): Whether to apply norm on lateral. + Default: False. + conv_cfg (dict): Config dict for convolution layer. Default: None. + norm_cfg (dict): Config dict for normalization layer. Default: None. + act_cfg (str): Config dict for activation layer in ConvModule. + Default: None. + upsample_cfg (dict): Config dict for interpolate layer. + Default: `dict(mode='nearest')` + init_cfg (dict or list[dict], optional): Initialization config dict. + + Example: + >>> import torch + >>> in_channels = [2, 3, 5, 7] + >>> scales = [340, 170, 84, 43] + >>> inputs = [torch.rand(1, c, s, s) + ... for c, s in zip(in_channels, scales)] + >>> self = FPN(in_channels, 11, len(in_channels)).eval() + >>> outputs = self.forward(inputs) + >>> for i in range(len(outputs)): + ... print(f'outputs[{i}].shape = {outputs[i].shape}') + outputs[0].shape = torch.Size([1, 11, 340, 340]) + outputs[1].shape = torch.Size([1, 11, 170, 170]) + outputs[2].shape = torch.Size([1, 11, 84, 84]) + outputs[3].shape = torch.Size([1, 11, 43, 43]) + """ + + def __init__(self, + in_channels, + out_channels, + num_outs, + start_level=0, + end_level=-1, + add_extra_convs=False, + relu_before_extra_convs=False, + no_norm_on_lateral=False, + conv_cfg=None, + norm_cfg=None, + act_cfg=None, + upsample_cfg=dict(mode='nearest'), + init_cfg=dict( + type='Xavier', layer='Conv2d', distribution='uniform')): + super(CPFPN, self).__init__(init_cfg) + assert isinstance(in_channels, list) + self.in_channels = in_channels + self.out_channels = out_channels + self.num_ins = len(in_channels) + self.num_outs = num_outs + self.relu_before_extra_convs = relu_before_extra_convs + self.no_norm_on_lateral = no_norm_on_lateral + self.fp16_enabled = False + self.upsample_cfg = upsample_cfg.copy() + + if end_level == -1: + self.backbone_end_level = self.num_ins + assert num_outs >= self.num_ins - start_level + else: + # if end_level < inputs, no extra level is allowed + self.backbone_end_level = end_level + assert end_level <= len(in_channels) + assert num_outs == end_level - start_level + self.start_level = start_level + self.end_level = end_level + self.add_extra_convs = add_extra_convs + assert isinstance(add_extra_convs, (str, bool)) + if isinstance(add_extra_convs, str): + # Extra_convs_source choices: 'on_input', 'on_lateral', 'on_output' + assert add_extra_convs in ('on_input', 'on_lateral', 'on_output') + elif add_extra_convs: # True + self.add_extra_convs = 'on_input' + + self.lateral_convs = nn.ModuleList() + self.fpn_convs = nn.ModuleList() + + for i in range(self.start_level, self.backbone_end_level): + l_conv = ConvModule( + in_channels[i], + out_channels, + 1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg if not self.no_norm_on_lateral else None, + act_cfg=act_cfg, + inplace=False) + self.lateral_convs.append(l_conv) + if i == 0: + fpn_conv = ConvModule( + out_channels, + out_channels, + 3, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + inplace=False) + self.fpn_convs.append(fpn_conv) + + # add extra conv layers (e.g., RetinaNet) + extra_levels = num_outs - self.backbone_end_level + self.start_level + if self.add_extra_convs and extra_levels >= 1: + for i in range(extra_levels): + if i == 0 and self.add_extra_convs == 'on_input': + in_channels = self.in_channels[self.backbone_end_level - 1] + 
else: + in_channels = out_channels + extra_fpn_conv = ConvModule( + in_channels, + out_channels, + 3, + stride=2, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + inplace=False) + self.fpn_convs.append(extra_fpn_conv) + + # @auto_fp16() + def forward(self, inputs): + """Forward function.""" + assert len(inputs) == len(self.in_channels) + + # build laterals + laterals = [ + lateral_conv(inputs[i + self.start_level]) + for i, lateral_conv in enumerate(self.lateral_convs) + ] + + # build top-down path + used_backbone_levels = len(laterals) + for i in range(used_backbone_levels - 1, 0, -1): + # In some cases, fixing `scale factor` (e.g. 2) is preferred, but + # it cannot co-exist with `size` in `F.interpolate`. + if 'scale_factor' in self.upsample_cfg: + laterals[i - 1] += F.interpolate(laterals[i], + **self.upsample_cfg) + else: + prev_shape = laterals[i - 1].shape[2:] + laterals[i - 1] += F.interpolate( + laterals[i], size=prev_shape, **self.upsample_cfg) + + # build outputs + # part 1: from original levels + outs = [ + self.fpn_convs[i](laterals[i]) if i == 0 else laterals[i] + for i in range(used_backbone_levels) + ] + # part 2: add extra levels + if self.num_outs > len(outs): + # use max pool to get more levels on top of outputs + # (e.g., Faster R-CNN, Mask R-CNN) + if not self.add_extra_convs: + for i in range(self.num_outs - used_backbone_levels): + outs.append(F.max_pool2d(outs[-1], 1, stride=2)) + # add conv layers on top of original feature maps (RetinaNet) + else: + if self.add_extra_convs == 'on_input': + extra_source = inputs[self.backbone_end_level - 1] + elif self.add_extra_convs == 'on_lateral': + extra_source = laterals[-1] + elif self.add_extra_convs == 'on_output': + extra_source = outs[-1] + else: + raise NotImplementedError + outs.append(self.fpn_convs[used_backbone_levels](extra_source)) + for i in range(used_backbone_levels + 1, self.num_outs): + if self.relu_before_extra_convs: + outs.append(self.fpn_convs[i](F.relu(outs[-1]))) + else: + outs.append(self.fpn_convs[i](outs[-1])) + return tuple(outs) diff --git a/projects/PETR/petr/grid_mask.py b/projects/PETR/petr/grid_mask.py new file mode 100644 index 0000000000..279d6b2b17 --- /dev/null +++ b/projects/PETR/petr/grid_mask.py @@ -0,0 +1,146 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
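# GridMask is an information-dropping augmentation: with `use_h` and `use_w`
# enabled, a randomly shifted (and randomly rotated) grid is drawn over the
# image, and with mode=1 (the setting PETR uses) the square cells of that grid
# are zeroed out. Nothing is applied at eval time, and each call skips the
# augmentation with probability 1 - prob.
# A minimal usage sketch (hypothetical shapes; note that `GridMask.forward`
# moves the mask to CUDA, so the input is assumed to live on the GPU):
#
#   gm = GridMask(True, True, rotate=1, offset=False, ratio=0.5, mode=1,
#                 prob=0.7).train()
#   out = gm(torch.rand(6, 3, 320, 800).cuda())  # six camera views, shape kept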
+import numpy as np +import torch +import torch.nn as nn +from PIL import Image + + +class Grid(object): + + def __init__(self, + use_h, + use_w, + rotate=1, + offset=False, + ratio=0.5, + mode=0, + prob=1., + length=1): + self.use_h = use_h + self.use_w = use_w + self.rotate = rotate + self.offset = offset + self.ratio = ratio + self.mode = mode + self.st_prob = prob + self.prob = prob + self.length = length + + def set_prob(self, epoch, max_epoch): + self.prob = self.st_prob * epoch / max_epoch + + def __call__(self, img, label): + if np.random.rand() > self.prob: + return img, label + h = img.size(1) + w = img.size(2) + self.d1 = 2 + self.d2 = min(h, w) + hh = int(1.5 * h) + ww = int(1.5 * w) + d = np.random.randint(self.d1, self.d2) + if self.ratio == 1: + self.length = np.random.randint(1, d) + else: + self.length = min(max(int(d * self.ratio + 0.5), 1), d - 1) + mask = np.ones((hh, ww), np.float32) + st_h = np.random.randint(d) + st_w = np.random.randint(d) + if self.use_h: + for i in range(hh // d): + s = d * i + st_h + t = min(s + self.length, hh) + mask[s:t, :] *= 0 + if self.use_w: + for i in range(ww // d): + s = d * i + st_w + t = min(s + self.length, ww) + mask[:, s:t] *= 0 + + r = np.random.randint(self.rotate) + mask = Image.fromarray(np.uint8(mask)) + mask = mask.rotate(r) + mask = np.asarray(mask) + mask = mask[(hh - h) // 2:(hh - h) // 2 + h, + (ww - w) // 2:(ww - w) // 2 + w] + + mask = torch.from_numpy(mask).float() + if self.mode == 1: + mask = 1 - mask + + mask = mask.expand_as(img) + if self.offset: + offset = torch.from_numpy(2 * (np.random.rand(h, w) - 0.5)).float() + offset = (1 - mask) * offset + img = img * mask + offset + else: + img = img * mask + + return img, label + + +class GridMask(nn.Module): + + def __init__(self, + use_h, + use_w, + rotate=1, + offset=False, + ratio=0.5, + mode=0, + prob=1.): + super(GridMask, self).__init__() + self.use_h = use_h + self.use_w = use_w + self.rotate = rotate + self.offset = offset + self.ratio = ratio + self.mode = mode + self.st_prob = prob + self.prob = prob + + def set_prob(self, epoch, max_epoch): + self.prob = self.st_prob * epoch / max_epoch # + 1.#0.5 + + def forward(self, x): + if np.random.rand() > self.prob or not self.training: + return x + n, c, h, w = x.size() + x = x.view(-1, h, w) + hh = int(1.5 * h) + ww = int(1.5 * w) + d = np.random.randint(2, h) + self.length = min(max(int(d * self.ratio + 0.5), 1), d - 1) + mask = np.ones((hh, ww), np.float32) + st_h = np.random.randint(d) + st_w = np.random.randint(d) + if self.use_h: + for i in range(hh // d): + s = d * i + st_h + t = min(s + self.length, hh) + mask[s:t, :] *= 0 + if self.use_w: + for i in range(ww // d): + s = d * i + st_w + t = min(s + self.length, ww) + mask[:, s:t] *= 0 + + r = np.random.randint(self.rotate) + mask = Image.fromarray(np.uint8(mask)) + mask = mask.rotate(r) + mask = np.asarray(mask) + mask = mask[(hh - h) // 2:(hh - h) // 2 + h, + (ww - w) // 2:(ww - w) // 2 + w] + + mask = torch.from_numpy(mask).float().cuda() + if self.mode == 1: + mask = 1 - mask + mask = mask.expand_as(x) + if self.offset: + offset = torch.from_numpy( + 2 * (np.random.rand(h, w) - 0.5)).float().cuda() + x = x * mask + offset * (1 - mask) + else: + x = x * mask + + return x.view(n, c, h, w) diff --git a/projects/PETR/petr/hungarian_assigner_3d.py b/projects/PETR/petr/hungarian_assigner_3d.py new file mode 100644 index 0000000000..860032324f --- /dev/null +++ b/projects/PETR/petr/hungarian_assigner_3d.py @@ -0,0 +1,142 @@ +# 
------------------------------------------------------------------------ +# Copyright (c) 2021 megvii-model. All Rights Reserved. +# ------------------------------------------------------------------------ +# Modified from DETR3D (https://github.com/WangYueFt/detr3d) +# Copyright (c) 2021 Wang, Yue +# ------------------------------------------------------------------------ +# Modified from mmdetection (https://github.com/open-mmlab/mmdetection) +# Copyright (c) OpenMMLab. All rights reserved. +# ------------------------------------------------------------------------ +import torch +from mmdet.models.task_modules import AssignResult, BaseAssigner + +from mmdet3d.registry import TASK_UTILS +from projects.PETR.petr.utils import normalize_bbox + +try: + from scipy.optimize import linear_sum_assignment +except ImportError: + linear_sum_assignment = None + + +@TASK_UTILS.register_module() +class HungarianAssigner3D(BaseAssigner): + """Computes one-to-one matching between predictions and ground truth. This class computes an assignment between the targets and the predictions based on the costs. The costs are a weighted sum of three components: classification cost, regression L1 cost and regression iou cost. The targets don't include the no_object, so generally there are more predictions than targets. After the one-to-one matching, the un-matched are treated as backgrounds. Thus each query prediction will be assigned with `0` or a positive integer indicating the ground truth index: + + - 0: negative sample, no assigned gt + - positive integer: positive sample, index (1-based) of assigned gt + Args: + cls_cost (dict, optional): Config of the classification cost. + Default dict(type='ClassificationCost', weight=1.). + reg_cost (dict, optional): Config of the regression L1 cost. + Default dict(type='BBoxL1Cost', weight=1.0). + iou_cost (dict, optional): Config of the regression iou cost. + Default dict(type='IoUCost', weight=0.0). + pc_range (list[float], optional): Point cloud range + [x_min, y_min, z_min, x_max, y_max, z_max] used to normalize + ground truth boxes. Default None. + """ + + def __init__(self, + cls_cost=dict(type='ClassificationCost', weight=1.), + reg_cost=dict(type='BBoxL1Cost', weight=1.0), + iou_cost=dict(type='IoUCost', weight=0.0), + pc_range=None): + self.cls_cost = TASK_UTILS.build(cls_cost) + self.reg_cost = TASK_UTILS.build(reg_cost) + self.iou_cost = TASK_UTILS.build(iou_cost) + self.pc_range = pc_range + + def assign(self, + bbox_pred, + cls_pred, + gt_bboxes, + gt_labels, + gt_bboxes_ignore=None, + eps=1e-7): + """Computes one-to-one matching based on the weighted costs. This method assigns each query prediction to a ground truth or background. The `assigned_gt_inds` with -1 means don't care, 0 means negative sample, and positive number is the index (1-based) of assigned gt. The assignment is done in the following steps, the order matters. 1. assign every prediction to -1 2. compute the weighted costs 3. do Hungarian matching on CPU based on the costs 4. assign all to 0 (background) first, then for each matched pair between predictions and gts, treat this prediction as foreground and assign the corresponding gt index (plus 1) to it. Args: + bbox_pred (Tensor): Predicted boxes in the normalized format + (cx, cy, w, l, cz, h, rot_sine, rot_cosine, vx, vy). Shape + [num_query, 10]. + cls_pred (Tensor): Predicted classification logits, shape + [num_query, num_class].
+ gt_bboxes (Tensor): Ground truth boxes with unnormalized + coordinates (x1, y1, x2, y2). Shape [num_gt, 4]. + gt_labels (Tensor): Label of `gt_bboxes`, shape (num_gt,). + gt_bboxes_ignore (Tensor, optional): Ground truth bboxes that are + labelled as `ignored`. Default None. + eps (int | float, optional): A value added to the denominator for + numerical stability. Default 1e-7. + Returns: + :obj:`AssignResult`: The assigned result. + """ + assert gt_bboxes_ignore is None, \ + 'Only case when gt_bboxes_ignore is None is supported.' + num_gts, num_bboxes = gt_bboxes.size(0), bbox_pred.size(0) + + # 1. assign -1 by default + assigned_gt_inds = bbox_pred.new_full((num_bboxes, ), + -1, + dtype=torch.long) + assigned_labels = bbox_pred.new_full((num_bboxes, ), + -1, + dtype=torch.long) + if num_gts == 0 or num_bboxes == 0: + # No ground truth or boxes, return empty assignment + if num_gts == 0: + # No ground truth, assign all to background + assigned_gt_inds[:] = 0 + return AssignResult( + num_gts, assigned_gt_inds, None, labels=assigned_labels) + + # 2. compute the weighted costs + # classification and bboxcost. + cls_cost = self.cls_cost(cls_pred, gt_labels) + # regression L1 cost + normalized_gt_bboxes = normalize_bbox(gt_bboxes, self.pc_range) + reg_cost = self.reg_cost(bbox_pred[:, :8], normalized_gt_bboxes[:, :8]) + + # weighted sum of above two costs + cost = cls_cost + reg_cost + + # 3. do Hungarian matching on CPU using linear_sum_assignment + cost = cost.detach().cpu() + if linear_sum_assignment is None: + raise ImportError('Please run "pip install scipy" ' + 'to install scipy first.') + cost = torch.nan_to_num(cost, nan=100.0, posinf=100.0, neginf=-100.0) + matched_row_inds, matched_col_inds = linear_sum_assignment(cost) + matched_row_inds = torch.from_numpy(matched_row_inds).to( + bbox_pred.device) + matched_col_inds = torch.from_numpy(matched_col_inds).to( + bbox_pred.device) + + # 4. assign backgrounds and foregrounds + # assign all indices to backgrounds first + assigned_gt_inds[:] = 0 + # assign foregrounds based on matching results + assigned_gt_inds[matched_row_inds] = matched_col_inds + 1 + assigned_labels[matched_row_inds] = gt_labels[matched_col_inds] + return AssignResult( + num_gts, assigned_gt_inds, None, labels=assigned_labels) diff --git a/projects/PETR/petr/match_cost.py b/projects/PETR/petr/match_cost.py new file mode 100644 index 0000000000..ee48d4ba4b --- /dev/null +++ b/projects/PETR/petr/match_cost.py @@ -0,0 +1,338 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch + +from mmdet3d.registry import TASK_UTILS + + +def fp16_clamp(x, min=None, max=None): + if not x.is_cuda and x.dtype == torch.float16: + # clamp for cpu float16, tensor fp16 has no clamp implementation + return x.float().clamp(min, max).half() + + return x.clamp(min, max) + + +def bbox_overlaps(bboxes1, bboxes2, mode='iou', is_aligned=False, eps=1e-6): + """Calculate overlap between two set of bboxes. + FP16 Contributed by https://github.com/open-mmlab/mmdetection/pull/4889 + Note: + Assume bboxes1 is M x 4, bboxes2 is N x 4, when mode is 'iou', + there are some new generated variable when calculating IOU + using bbox_overlaps function: + 1) is_aligned is False + area1: M x 1 + area2: N x 1 + lt: M x N x 2 + rb: M x N x 2 + wh: M x N x 2 + overlap: M x N x 1 + union: M x N x 1 + ious: M x N x 1 + Total memory: + S = (9 x N x M + N + M) * 4 Byte, + When using FP16, we can reduce: + R = (9 x N x M + N + M) * 4 / 2 Byte + R large than (N + M) * 4 * 2 is always true when N and M >= 1. 
+ Obviously, N + M <= N * M < 3 * N * M, when N >=2 and M >=2, + N + 1 < 3 * N, when N or M is 1. + Given M = 40 (ground truth), N = 400000 (three anchor boxes + in per grid, FPN, R-CNNs), + R = 275 MB (one times) + A special case (dense detection), M = 512 (ground truth), + R = 3516 MB = 3.43 GB + When the batch size is B, reduce: + B x R + Therefore, CUDA memory runs out frequently. + Experiments on GeForce RTX 2080Ti (11019 MiB): + | dtype | M | N | Use | Real | Ideal | + |:----:|:----:|:----:|:----:|:----:|:----:| + | FP32 | 512 | 400000 | 8020 MiB | -- | -- | + | FP16 | 512 | 400000 | 4504 MiB | 3516 MiB | 3516 MiB | + | FP32 | 40 | 400000 | 1540 MiB | -- | -- | + | FP16 | 40 | 400000 | 1264 MiB | 276MiB | 275 MiB | + 2) is_aligned is True + area1: N x 1 + area2: N x 1 + lt: N x 2 + rb: N x 2 + wh: N x 2 + overlap: N x 1 + union: N x 1 + ious: N x 1 + Total memory: + S = 11 x N * 4 Byte + When using FP16, we can reduce: + R = 11 x N * 4 / 2 Byte + So do the 'giou' (large than 'iou'). + Time-wise, FP16 is generally faster than FP32. + When gpu_assign_thr is not -1, it takes more time on cpu + but not reduce memory. + There, we can reduce half the memory and keep the speed. + If ``is_aligned`` is ``False``, then calculate the overlaps between each + bbox of bboxes1 and bboxes2, otherwise the overlaps between each aligned + pair of bboxes1 and bboxes2. + Args: + bboxes1 (Tensor): shape (B, m, 4) in format or empty. + bboxes2 (Tensor): shape (B, n, 4) in format or empty. + B indicates the batch dim, in shape (B1, B2, ..., Bn). + If ``is_aligned`` is ``True``, then m and n must be equal. + mode (str): "iou" (intersection over union), "iof" (intersection over + foreground) or "giou" (generalized intersection over union). + Default "iou". + is_aligned (bool, optional): If True, then m and n must be equal. + Default False. + eps (float, optional): A value added to the denominator for numerical + stability. Default 1e-6. + Returns: + Tensor: shape (m, n) if ``is_aligned`` is False else shape (m,) + Example: + >>> bboxes1 = torch.FloatTensor([ + >>> [0, 0, 10, 10], + >>> [10, 10, 20, 20], + >>> [32, 32, 38, 42], + >>> ]) + >>> bboxes2 = torch.FloatTensor([ + >>> [0, 0, 10, 20], + >>> [0, 10, 10, 19], + >>> [10, 10, 20, 20], + >>> ]) + >>> overlaps = bbox_overlaps(bboxes1, bboxes2) + >>> assert overlaps.shape == (3, 3) + >>> overlaps = bbox_overlaps(bboxes1, bboxes2, is_aligned=True) + >>> assert overlaps.shape == (3, ) + Example: + >>> empty = torch.empty(0, 4) + >>> nonempty = torch.FloatTensor([[0, 0, 10, 9]]) + >>> assert tuple(bbox_overlaps(empty, nonempty).shape) == (0, 1) + >>> assert tuple(bbox_overlaps(nonempty, empty).shape) == (1, 0) + >>> assert tuple(bbox_overlaps(empty, empty).shape) == (0, 0) + """ + + assert mode in ['iou', 'iof', 'giou'], f'Unsupported mode {mode}' + # Either the boxes are empty or the length of boxes' last dimension is 4 + assert (bboxes1.size(-1) == 4 or bboxes1.size(0) == 0) + assert (bboxes2.size(-1) == 4 or bboxes2.size(0) == 0) + + # Batch dim must be the same + # Batch dim: (B1, B2, ... 
Bn) + assert bboxes1.shape[:-2] == bboxes2.shape[:-2] + batch_shape = bboxes1.shape[:-2] + + rows = bboxes1.size(-2) + cols = bboxes2.size(-2) + if is_aligned: + assert rows == cols + + if rows * cols == 0: + if is_aligned: + return bboxes1.new(batch_shape + (rows, )) + else: + return bboxes1.new(batch_shape + (rows, cols)) + + area1 = (bboxes1[..., 2] - bboxes1[..., 0]) * ( + bboxes1[..., 3] - bboxes1[..., 1]) + area2 = (bboxes2[..., 2] - bboxes2[..., 0]) * ( + bboxes2[..., 3] - bboxes2[..., 1]) + + if is_aligned: + lt = torch.max(bboxes1[..., :2], bboxes2[..., :2]) # [B, rows, 2] + rb = torch.min(bboxes1[..., 2:], bboxes2[..., 2:]) # [B, rows, 2] + + wh = fp16_clamp(rb - lt, min=0) + overlap = wh[..., 0] * wh[..., 1] + + if mode in ['iou', 'giou']: + union = area1 + area2 - overlap + else: + union = area1 + if mode == 'giou': + enclosed_lt = torch.min(bboxes1[..., :2], bboxes2[..., :2]) + enclosed_rb = torch.max(bboxes1[..., 2:], bboxes2[..., 2:]) + else: + lt = torch.max(bboxes1[..., :, None, :2], + bboxes2[..., None, :, :2]) # [B, rows, cols, 2] + rb = torch.min(bboxes1[..., :, None, 2:], + bboxes2[..., None, :, 2:]) # [B, rows, cols, 2] + + wh = fp16_clamp(rb - lt, min=0) + overlap = wh[..., 0] * wh[..., 1] + + if mode in ['iou', 'giou']: + union = area1[..., None] + area2[..., None, :] - overlap + else: + union = area1[..., None] + if mode == 'giou': + enclosed_lt = torch.min(bboxes1[..., :, None, :2], + bboxes2[..., None, :, :2]) + enclosed_rb = torch.max(bboxes1[..., :, None, 2:], + bboxes2[..., None, :, 2:]) + + eps = union.new_tensor([eps]) + union = torch.max(union, eps) + ious = overlap / union + if mode in ['iou', 'iof']: + return ious + # calculate gious + enclose_wh = fp16_clamp(enclosed_rb - enclosed_lt, min=0) + enclose_area = enclose_wh[..., 0] * enclose_wh[..., 1] + enclose_area = torch.max(enclose_area, eps) + gious = ious - (enclose_area - union) / enclose_area + return gious + + +@TASK_UTILS.register_module() +class BBox3DL1Cost(object): + """BBox3DL1Cost. + + Args: + weight (int | float, optional): loss_weight + """ + + def __init__(self, weight=1.): + self.weight = weight + + def __call__(self, bbox_pred, gt_bboxes): + """ + Args: + bbox_pred (Tensor): Predicted boxes with normalized coordinates + (cx, cy, w, h), which are all in range [0, 1]. Shape + [num_query, 4]. + gt_bboxes (Tensor): Ground truth boxes with normalized + coordinates (x1, y1, x2, y2). Shape [num_gt, 4]. + Returns: + torch.Tensor: bbox_cost value with weight + """ + bbox_cost = torch.cdist(bbox_pred, gt_bboxes, p=1) + return bbox_cost * self.weight + + +@TASK_UTILS.register_module() +class FocalLossCost: + """FocalLossCost. + Args: + weight (int | float, optional): loss_weight + alpha (int | float, optional): focal_loss alpha + gamma (int | float, optional): focal_loss gamma + eps (float, optional): default 1e-12 + binary_input (bool, optional): Whether the input is binary, + default False. 
+ Examples: + >>> from mmdet.core.bbox.match_costs.match_cost import FocalLossCost + >>> import torch + >>> self = FocalLossCost() + >>> cls_pred = torch.rand(4, 3) + >>> gt_labels = torch.tensor([0, 1, 2]) + >>> factor = torch.tensor([10, 8, 10, 8]) + >>> self(cls_pred, gt_labels) + tensor([[-0.3236, -0.3364, -0.2699], + [-0.3439, -0.3209, -0.4807], + [-0.4099, -0.3795, -0.2929], + [-0.1950, -0.1207, -0.2626]]) + """ + + def __init__(self, + weight=1., + alpha=0.25, + gamma=2, + eps=1e-12, + binary_input=False): + self.weight = weight + self.alpha = alpha + self.gamma = gamma + self.eps = eps + self.binary_input = binary_input + + def _focal_loss_cost(self, cls_pred, gt_labels): + """ + Args: + cls_pred (Tensor): Predicted classification logits, shape + (num_query, num_class). + gt_labels (Tensor): Label of `gt_bboxes`, shape (num_gt,). + Returns: + torch.Tensor: cls_cost value with weight + """ + cls_pred = cls_pred.sigmoid() + neg_cost = -(1 - cls_pred + self.eps).log() * ( + 1 - self.alpha) * cls_pred.pow(self.gamma) + pos_cost = -(cls_pred + self.eps).log() * self.alpha * ( + 1 - cls_pred).pow(self.gamma) + + cls_cost = pos_cost[:, gt_labels] - neg_cost[:, gt_labels] + return cls_cost * self.weight + + def _mask_focal_loss_cost(self, cls_pred, gt_labels): + """ + Args: + cls_pred (Tensor): Predicted classfication logits + in shape (num_query, d1, ..., dn), dtype=torch.float32. + gt_labels (Tensor): Ground truth in shape (num_gt, d1, ..., dn), + dtype=torch.long. Labels should be binary. + Returns: + Tensor: Focal cost matrix with weight in shape\ + (num_query, num_gt). + """ + cls_pred = cls_pred.flatten(1) + gt_labels = gt_labels.flatten(1).float() + n = cls_pred.shape[1] + cls_pred = cls_pred.sigmoid() + neg_cost = -(1 - cls_pred + self.eps).log() * ( + 1 - self.alpha) * cls_pred.pow(self.gamma) + pos_cost = -(cls_pred + self.eps).log() * self.alpha * ( + 1 - cls_pred).pow(self.gamma) + + cls_cost = torch.einsum('nc,mc->nm', pos_cost, gt_labels) + \ + torch.einsum('nc,mc->nm', neg_cost, (1 - gt_labels)) + return cls_cost / n * self.weight + + def __call__(self, cls_pred, gt_labels): + """ + Args: + cls_pred (Tensor): Predicted classfication logits. + gt_labels (Tensor)): Labels. + Returns: + Tensor: Focal cost matrix with weight in shape\ + (num_query, num_gt). + """ + if self.binary_input: + return self._mask_focal_loss_cost(cls_pred, gt_labels) + else: + return self._focal_loss_cost(cls_pred, gt_labels) + + +@TASK_UTILS.register_module() +class IoUCost: + """IoUCost. + Args: + iou_mode (str, optional): iou mode such as 'iou' | 'giou' + weight (int | float, optional): loss weight + Examples: + >>> from mmdet.core.bbox.match_costs.match_cost import IoUCost + >>> import torch + >>> self = IoUCost() + >>> bboxes = torch.FloatTensor([[1,1, 2, 2], [2, 2, 3, 4]]) + >>> gt_bboxes = torch.FloatTensor([[0, 0, 2, 4], [1, 2, 3, 4]]) + >>> self(bboxes, gt_bboxes) + tensor([[-0.1250, 0.1667], + [ 0.1667, -0.5000]]) + """ + + def __init__(self, iou_mode='giou', weight=1.): + self.weight = weight + self.iou_mode = iou_mode + + def __call__(self, bboxes, gt_bboxes): + """ + Args: + bboxes (Tensor): Predicted boxes with unnormalized coordinates + (x1, y1, x2, y2). Shape (num_query, 4). + gt_bboxes (Tensor): Ground truth boxes with unnormalized + coordinates (x1, y1, x2, y2). Shape (num_gt, 4). 
+ Returns: + torch.Tensor: iou_cost value with weight + """ + # overlaps: [num_bboxes, num_gt] + overlaps = bbox_overlaps( + bboxes, gt_bboxes, mode=self.iou_mode, is_aligned=False) + # The 1 is a constant that doesn't change the matching, so omitted. + iou_cost = -overlaps + return iou_cost * self.weight diff --git a/projects/PETR/petr/nms_free_coder.py b/projects/PETR/petr/nms_free_coder.py new file mode 100644 index 0000000000..d1415d4c0e --- /dev/null +++ b/projects/PETR/petr/nms_free_coder.py @@ -0,0 +1,246 @@ +# ------------------------------------------------------------------------ +# Copyright (c) 2021 megvii-model. All Rights Reserved. +# ------------------------------------------------------------------------ +# Modified from DETR3D (https://github.com/WangYueFt/detr3d) +# Copyright (c) 2021 Wang, Yue +# ------------------------------------------------------------------------ +# Modified from mmdetection3d (https://github.com/open-mmlab/mmdetection3d) +# Copyright (c) OpenMMLab. All rights reserved. +# ------------------------------------------------------------------------ +import torch +import torch.nn.functional as F +from mmdet.models.task_modules import BaseBBoxCoder + +from mmdet3d.registry import TASK_UTILS +from projects.PETR.petr.utils import denormalize_bbox + + +@TASK_UTILS.register_module() +class NMSFreeCoder(BaseBBoxCoder): + """Bbox coder for NMS-free detector. + + Args: + pc_range (list[float]): Range of point cloud. + post_center_range (list[float]): Limit of the center. + Default: None. + max_num (int): Max number to be kept. Default: 100. + score_threshold (float): Threshold to filter boxes based on score. + Default: None. + code_size (int): Code size of bboxes. Default: 9 + """ + + def __init__(self, + pc_range, + voxel_size=None, + post_center_range=None, + max_num=100, + score_threshold=None, + num_classes=10): + + self.pc_range = pc_range + self.voxel_size = voxel_size + self.post_center_range = post_center_range + self.max_num = max_num + self.score_threshold = score_threshold + self.num_classes = num_classes + + def encode(self): + pass + + def decode_single(self, cls_scores, bbox_preds): + """Decode bboxes. + + Args: + cls_scores (Tensor): Outputs from the classification head, \ + shape [num_query, cls_out_channels]. Note \ + cls_out_channels should includes background. + bbox_preds (Tensor): Outputs from the regression \ + head with normalized coordinate format \ + (cx, cy, w, l, cz, h, rot_sine, rot_cosine, vx, vy). \ + Shape [num_query, 9]. + Returns: + list[dict]: Decoded boxes. 
+ """ + max_num = self.max_num + + cls_scores = cls_scores.sigmoid() + scores, indexes = cls_scores.view(-1).topk(max_num) + labels = indexes % self.num_classes + bbox_index = indexes // self.num_classes + bbox_preds = bbox_preds[bbox_index] + + final_box_preds = denormalize_bbox(bbox_preds, self.pc_range) + final_scores = scores + final_preds = labels + + # use score threshold + if self.score_threshold is not None: + thresh_mask = final_scores > self.score_threshold + if self.post_center_range is not None: + self.post_center_range = torch.tensor( + self.post_center_range, device=scores.device) + + mask = (final_box_preds[..., :3] >= + self.post_center_range[:3]).all(1) + mask &= (final_box_preds[..., :3] <= + self.post_center_range[3:]).all(1) + + if self.score_threshold: + mask &= thresh_mask + + boxes3d = final_box_preds[mask] + scores = final_scores[mask] + labels = final_preds[mask] + predictions_dict = { + 'bboxes': boxes3d, + 'scores': scores, + 'labels': labels + } + + else: + raise NotImplementedError( + 'Need to reorganize output as a batch, only ' + 'support post_center_range is not None for now!') + return predictions_dict + + def decode(self, preds_dicts): + """Decode bboxes. + + Args: + all_cls_scores (Tensor): Outputs from the classification head, \ + shape [nb_dec, bs, num_query, cls_out_channels]. Note \ + cls_out_channels should includes background. + all_bbox_preds (Tensor): Sigmoid outputs from the regression \ + head with normalized coordinate format \ + (cx, cy, w, l, cz, h, rot_sine, rot_cosine, vx, vy). \ + Shape [nb_dec, bs, num_query, 9]. + Returns: + list[dict]: Decoded boxes. + """ + all_cls_scores = preds_dicts['all_cls_scores'][-1] + all_bbox_preds = preds_dicts['all_bbox_preds'][-1] + + batch_size = all_cls_scores.size()[0] + predictions_list = [] + for i in range(batch_size): + predictions_list.append( + self.decode_single(all_cls_scores[i], all_bbox_preds[i])) + return predictions_list + + +@TASK_UTILS.register_module() +class NMSFreeClsCoder(BaseBBoxCoder): + """Bbox coder for NMS-free detector. + + Args: + pc_range (list[float]): Range of point cloud. + post_center_range (list[float]): Limit of the center. + Default: None. + max_num (int): Max number to be kept. Default: 100. + score_threshold (float): Threshold to filter boxes based on score. + Default: None. + code_size (int): Code size of bboxes. Default: 9 + """ + + def __init__(self, + pc_range, + voxel_size=None, + post_center_range=None, + max_num=100, + score_threshold=None, + num_classes=10): + + self.pc_range = pc_range + self.voxel_size = voxel_size + self.post_center_range = post_center_range + self.max_num = max_num + self.score_threshold = score_threshold + self.num_classes = num_classes + + def encode(self): + pass + + def decode_single(self, cls_scores, bbox_preds): + """Decode bboxes. + + Args: + cls_scores (Tensor): Outputs from the classification head, \ + shape [num_query, cls_out_channels]. Note \ + cls_out_channels should includes background. + bbox_preds (Tensor): Outputs from the regression \ + head with normalized coordinate format \ + (cx, cy, w, l, cz, h, rot_sine, rot_cosine, vx, vy). \ + Shape [num_query, 9]. + Returns: + list[dict]: Decoded boxes. 
+ """ + max_num = self.max_num + + # cls_scores = cls_scores.sigmoid() + # scores, indexes = cls_scores.view(-1).topk(max_num) + # labels = indexes % self.num_classes + # bbox_index = indexes // self.num_classes + # bbox_preds = bbox_preds[bbox_index] + + cls_scores, labels = F.softmax(cls_scores, dim=-1)[..., :-1].max(-1) + scores, indexes = cls_scores.view(-1).topk(max_num) + labels = labels[indexes] + bbox_preds = bbox_preds[indexes] + + final_box_preds = denormalize_bbox(bbox_preds, self.pc_range) + final_scores = scores + final_preds = labels + + # use score threshold + if self.score_threshold is not None: + thresh_mask = final_scores > self.score_threshold + if self.post_center_range is not None: + self.post_center_range = torch.tensor( + self.post_center_range, device=scores.device) + + mask = (final_box_preds[..., :3] >= + self.post_center_range[:3]).all(1) + mask &= (final_box_preds[..., :3] <= + self.post_center_range[3:]).all(1) + + if self.score_threshold: + mask &= thresh_mask + + boxes3d = final_box_preds[mask] + scores = final_scores[mask] + labels = final_preds[mask] + predictions_dict = { + 'bboxes': boxes3d, + 'scores': scores, + 'labels': labels + } + + else: + raise NotImplementedError( + 'Need to reorganize output as a batch, only ' + 'support post_center_range is not None for now!') + return predictions_dict + + def decode(self, preds_dicts): + """Decode bboxes. + + Args: + all_cls_scores (Tensor): Outputs from the classification head, \ + shape [nb_dec, bs, num_query, cls_out_channels]. Note \ + cls_out_channels should includes background. + all_bbox_preds (Tensor): Sigmoid outputs from the regression \ + head with normalized coordinate format \ + (cx, cy, w, l, cz, h, rot_sine, rot_cosine, vx, vy). \ + Shape [nb_dec, bs, num_query, 9]. + Returns: + list[dict]: Decoded boxes. + """ + all_cls_scores = preds_dicts['all_cls_scores'][-1] + all_bbox_preds = preds_dicts['all_bbox_preds'][-1] + + batch_size = all_cls_scores.size()[0] + predictions_list = [] + for i in range(batch_size): + predictions_list.append( + self.decode_single(all_cls_scores[i], all_bbox_preds[i])) + return predictions_list diff --git a/projects/PETR/petr/petr.py b/projects/PETR/petr/petr.py new file mode 100644 index 0000000000..030d9d0d15 --- /dev/null +++ b/projects/PETR/petr/petr.py @@ -0,0 +1,299 @@ +# ------------------------------------------------------------------------ +# Copyright (c) 2022 megvii-model. All Rights Reserved. +# ------------------------------------------------------------------------ +# Modified from DETR3D (https://github.com/WangYueFt/detr3d) +# Copyright (c) 2021 Wang, Yue +# ------------------------------------------------------------------------ +# Modified from mmdetection3d (https://github.com/open-mmlab/mmdetection3d) +# Copyright (c) OpenMMLab. All rights reserved. 
+# ------------------------------------------------------------------------
+
+import numpy as np
+import torch
+from mmengine.structures import InstanceData
+
+import mmdet3d
+from mmdet3d.models.detectors.mvx_two_stage import MVXTwoStageDetector
+from mmdet3d.registry import MODELS
+from mmdet3d.structures.bbox_3d import LiDARInstance3DBoxes, limit_period
+from mmdet3d.structures.ops import bbox3d2result
+from .grid_mask import GridMask
+
+
+@MODELS.register_module()
+class PETR(MVXTwoStageDetector):
+    """PETR."""
+
+    def __init__(self,
+                 use_grid_mask=False,
+                 pts_voxel_layer=None,
+                 pts_middle_encoder=None,
+                 pts_fusion_layer=None,
+                 img_backbone=None,
+                 pts_backbone=None,
+                 img_neck=None,
+                 pts_neck=None,
+                 pts_bbox_head=None,
+                 img_roi_head=None,
+                 img_rpn_head=None,
+                 train_cfg=None,
+                 test_cfg=None,
+                 init_cfg=None,
+                 data_preprocessor=None,
+                 **kwargs):
+        super(PETR,
+              self).__init__(pts_voxel_layer, pts_middle_encoder,
+                             pts_fusion_layer, img_backbone, pts_backbone,
+                             img_neck, pts_neck, pts_bbox_head, img_roi_head,
+                             img_rpn_head, train_cfg, test_cfg, init_cfg,
+                             data_preprocessor)
+        self.grid_mask = GridMask(
+            True, True, rotate=1, offset=False, ratio=0.5, mode=1, prob=0.7)
+        self.use_grid_mask = use_grid_mask
+
+    def extract_img_feat(self, img, img_metas):
+        """Extract features of images."""
+        if isinstance(img, list):
+            img = torch.stack(img, dim=0)
+
+        B = img.size(0)
+        if img is not None:
+            input_shape = img.shape[-2:]
+            # update real input shape of each single img
+            for img_meta in img_metas:
+                img_meta.update(input_shape=input_shape)
+            if img.dim() == 5:
+                if img.size(0) == 1 and img.size(1) != 1:
+                    img.squeeze_()
+                else:
+                    B, N, C, H, W = img.size()
+                    img = img.view(B * N, C, H, W)
+            if self.use_grid_mask:
+                img = self.grid_mask(img)
+            img_feats = self.img_backbone(img)
+            if isinstance(img_feats, dict):
+                img_feats = list(img_feats.values())
+        else:
+            return None
+        if self.with_img_neck:
+            img_feats = self.img_neck(img_feats)
+        img_feats_reshaped = []
+        for img_feat in img_feats:
+            BN, C, H, W = img_feat.size()
+            img_feats_reshaped.append(img_feat.view(B, int(BN / B), C, H, W))
+        return img_feats_reshaped
+
+    # @auto_fp16(apply_to=('img'), out_fp32=True)
+    def extract_feat(self, img, img_metas):
+        """Extract features from images."""
+        img_feats = self.extract_img_feat(img, img_metas)
+        return img_feats
+
+    def forward_pts_train(self,
+                          pts_feats,
+                          gt_bboxes_3d,
+                          gt_labels_3d,
+                          img_metas,
+                          gt_bboxes_ignore=None):
+        """Forward function for the detection branch.
+
+        Note that PETR is camera-only: the ``pts`` naming follows
+        :class:`MVXTwoStageDetector`, but the features passed here are
+        image features.
+
+        Args:
+            pts_feats (list[torch.Tensor]): Features fed to the bbox head.
+            gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`]): Ground truth
+                boxes for each sample.
+            gt_labels_3d (list[torch.Tensor]): Ground truth labels for
+                boxes of each sample.
+            img_metas (list[dict]): Meta information of samples.
+            gt_bboxes_ignore (list[torch.Tensor], optional): Ground truth
+                boxes to be ignored. Defaults to None.
+        Returns:
+            dict: Losses of each branch.
+        """
+        outs = self.pts_bbox_head(pts_feats, img_metas)
+        loss_inputs = [gt_bboxes_3d, gt_labels_3d, outs]
+        losses = self.pts_bbox_head.loss_by_feat(*loss_inputs)
+
+        return losses
+
+    def _forward(self, mode='loss', **kwargs):
+        """Forward in ``tensor`` mode.
+
+        Not implemented for PETR yet; use :meth:`loss` for training and
+        :meth:`predict` for inference instead.
+        """
+        raise NotImplementedError('tensor mode is yet to be added')
+
+    def loss(self,
+             inputs=None,
+             data_samples=None,
+             mode=None,
+             points=None,
+             img_metas=None,
+             gt_bboxes_3d=None,
+             gt_labels_3d=None,
+             gt_labels=None,
+             gt_bboxes=None,
+             img=None,
+             proposals=None,
+             gt_bboxes_ignore=None,
+             img_depth=None,
+             img_mask=None):
+        """Forward training function.
+
+        Args:
+            inputs (dict): Input dict containing the multi-view images
+                under the key 'imgs'.
+            data_samples (list[:obj:`Det3DDataSample`]): Data samples that
+                carry the meta information and 3D ground truth instances of
+                each sample.
+            points (list[torch.Tensor], optional): Points of each sample.
+                Defaults to None.
+            img_metas (list[dict], optional): Meta information of each sample.
+                Defaults to None.
+            gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`], optional):
+                Ground truth 3D boxes. Defaults to None.
+            gt_labels_3d (list[torch.Tensor], optional): Ground truth labels
+                of 3D boxes. Defaults to None.
+            gt_labels (list[torch.Tensor], optional): Ground truth labels
+                of 2D boxes in images. Defaults to None.
+            gt_bboxes (list[torch.Tensor], optional): Ground truth 2D boxes in
+                images. Defaults to None.
+            img (torch.Tensor, optional): Images of each sample with shape
+                (N, C, H, W). Defaults to None.
+            proposals (list[torch.Tensor], optional): Predicted proposals
+                used for training Fast RCNN. Defaults to None.
+            gt_bboxes_ignore (list[torch.Tensor], optional): Ground truth
+                2D boxes in images to be ignored. Defaults to None.
+        Returns:
+            dict: Losses of different branches.
+        """
+        img = inputs['imgs']
+        batch_img_metas = [ds.metainfo for ds in data_samples]
+        batch_gt_instances_3d = [ds.gt_instances_3d for ds in data_samples]
+        gt_bboxes_3d = [gt.bboxes_3d for gt in batch_gt_instances_3d]
+        gt_labels_3d = [gt.labels_3d for gt in batch_gt_instances_3d]
+        gt_bboxes_ignore = None
+
+        gt_bboxes_3d = self.LidarBox3dVersionTransform(gt_bboxes_3d)
+
+        batch_img_metas = self.add_lidar2img(img, batch_img_metas)
+
+        img_feats = self.extract_feat(img=img, img_metas=batch_img_metas)
+
+        losses = dict()
+        losses_pts = self.forward_pts_train(img_feats, gt_bboxes_3d,
+                                            gt_labels_3d, batch_img_metas,
+                                            gt_bboxes_ignore)
+        losses.update(losses_pts)
+        return losses
+
+    def predict(self, inputs=None, data_samples=None, mode=None, **kwargs):
+        """Forward prediction function.
+
+        Args:
+            inputs (dict): Input dict containing the multi-view images
+                under the key 'imgs'.
+            data_samples (list[:obj:`Det3DDataSample`]): Batch data samples.
+        Returns:
+            list[:obj:`Det3DDataSample`]: Data samples with the predicted
+                3D instances stored in ``pred_instances_3d``.
+        """
+        img = inputs['imgs']
+        batch_img_metas = [ds.metainfo for ds in data_samples]
+        for var, name in [(batch_img_metas, 'img_metas')]:
+            if not isinstance(var, list):
+                raise TypeError('{} must be a list, but got {}'.format(
+                    name, type(var)))
+        img = [img] if img is None else img
+
+        batch_img_metas = self.add_lidar2img(img, batch_img_metas)
+
+        results_list_3d = self.simple_test(batch_img_metas, img, **kwargs)
+
+        for i, data_sample in enumerate(data_samples):
+            results_list_3d_i = InstanceData(
+                metainfo=results_list_3d[i]['pts_bbox'])
+            data_sample.pred_instances_3d = results_list_3d_i
+            data_sample.pred_instances = InstanceData()
+
+        return data_samples
+
+    def simple_test_pts(self, x, img_metas, rescale=False):
+        """Test function of the detection branch."""
+        outs = self.pts_bbox_head(x, img_metas)
+        bbox_list = self.pts_bbox_head.get_bboxes(
+            outs, img_metas, rescale=rescale)
+        bbox_results = [
+            bbox3d2result(bboxes, scores, labels)
+            for bboxes, scores, labels in bbox_list
+        ]
+        return bbox_results
+
+    def simple_test(self, img_metas, img=None, rescale=False):
+        """Test function without augmentation."""
+        img_feats = self.extract_feat(img=img, img_metas=img_metas)
+
+        bbox_list = [dict() for i in range(len(img_metas))]
+        bbox_pts = self.simple_test_pts(img_feats, img_metas, rescale=rescale)
+        for result_dict, pts_bbox in zip(bbox_list, bbox_pts):
+            result_dict['pts_bbox'] = pts_bbox
+        return bbox_list
+
+    def aug_test_pts(self, feats, img_metas, rescale=False):
+        """Test function of the detection branch with augmented features."""
+        feats_list = []
+        for j in range(len(feats[0])):
+            feats_list_level = []
+            for i in range(len(feats)):
+                feats_list_level.append(feats[i][j])
+            feats_list.append(torch.stack(feats_list_level, -1).mean(-1))
+        outs = self.pts_bbox_head(feats_list, img_metas)
+        bbox_list = self.pts_bbox_head.get_bboxes(
+            outs, img_metas, rescale=rescale)
+        bbox_results = [
+            bbox3d2result(bboxes, scores, labels)
+            for bboxes, scores, labels in bbox_list
+        ]
+        return bbox_results
+
+    def aug_test(self, img_metas, imgs=None, rescale=False):
+        """Test function with augmentation."""
+        img_feats = self.extract_feats(img_metas, imgs)
+        img_metas = img_metas[0]
+        bbox_list = [dict() for i in range(len(img_metas))]
+        bbox_pts = self.aug_test_pts(img_feats, img_metas, rescale)
+        for result_dict, pts_bbox in zip(bbox_list, bbox_pts):
+            result_dict['pts_bbox'] = pts_bbox
+        return bbox_list
+
+    # TODO: may need speed-up
+    def add_lidar2img(self, img, batch_input_metas):
+        """Add the 'lidar2img' transformation matrix into batch_input_metas.
+
+        Args:
+            img (torch.Tensor): Multi-view images of the batch, used to
+                record the per-view image shapes.
+            batch_input_metas (list[dict]): Meta information of multiple
+                inputs in a batch.
+        Returns:
+            batch_input_metas (list[dict]): Meta info with lidar2img added.
+        """
+        for meta in batch_input_metas:
+            # Collect one lidar2img matrix per camera view of this sample.
+            lidar2img_rts = []
+            # obtain lidar to image transformation matrix
+            for i in range(len(meta['cam2img'])):
+                lidar2cam_rt = torch.tensor(meta['lidar2cam'][i]).double()
+                intrinsic = torch.tensor(meta['cam2img'][i]).double()
+                viewpad = torch.eye(4).double()
+                viewpad[:intrinsic.shape[0], :intrinsic.shape[1]] = intrinsic
+                lidar2img_rt = (viewpad @ lidar2cam_rt)
+                # The extrinsics mean the transformation from lidar to camera.
+                # If anyone wants to use the extrinsics as sensor to lidar,
+                # please use np.linalg.inv(lidar2cam_rt.T)
+                # and modify the ResizeCropFlipImage
+                # and LoadMultiViewImageFromMultiSweepsFiles.
+                lidar2img_rts.append(lidar2img_rt)
+            meta['lidar2img'] = lidar2img_rts
+            meta['img_shape'] = [i.shape for i in img[0]]
+        return batch_input_metas
+
+    def LidarBox3dVersionTransform(self, gt_bboxes_3d):
+        if int(mmdet3d.__version__[0]) >= 1:
+            # Hack: adapt the ground-truth boxes to the mmdet3d v1.x box
+            # convention (swapped dx/dy and a shifted yaw origin).
+            gt_bboxes_3d = gt_bboxes_3d[0].tensor
+
+            gt_bboxes_3d[:, [3, 4]] = gt_bboxes_3d[:, [4, 3]]
+            gt_bboxes_3d[:, 6] = -gt_bboxes_3d[:, 6] - np.pi / 2
+            gt_bboxes_3d[:, 6] = limit_period(
+                gt_bboxes_3d[:, 6], period=np.pi * 2)
+
+            gt_bboxes_3d = LiDARInstance3DBoxes(gt_bboxes_3d, box_dim=9)
+            gt_bboxes_3d = [gt_bboxes_3d]
+        return gt_bboxes_3d
diff --git a/projects/PETR/petr/petr_head.py b/projects/PETR/petr/petr_head.py
new file mode 100644
index 0000000000..2b6e088e57
--- /dev/null
+++ b/projects/PETR/petr/petr_head.py
@@ -0,0 +1,825 @@
+# ------------------------------------------------------------------------
+# Copyright (c) 2022 megvii-model. All Rights Reserved.
+# ------------------------------------------------------------------------
+# Modified from DETR3D (https://github.com/WangYueFt/detr3d)
+# Copyright (c) 2021 Wang, Yue
+# ------------------------------------------------------------------------
+# Modified from mmdetection3d (https://github.com/open-mmlab/mmdetection3d)
+# Copyright (c) OpenMMLab. All rights reserved.
+# ------------------------------------------------------------------------ +import math + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import Conv2d, Linear +from mmdet.models.dense_heads.anchor_free_head import AnchorFreeHead +from mmdet.models.layers import NormedLinear +from mmdet.models.layers.transformer import inverse_sigmoid +from mmdet.models.utils import multi_apply +from mmengine.model.weight_init import bias_init_with_prob +from mmengine.structures import InstanceData + +from mmdet3d.registry import MODELS, TASK_UTILS +from projects.PETR.petr.utils import normalize_bbox + + +def pos2posemb3d(pos, num_pos_feats=128, temperature=10000): + scale = 2 * math.pi + pos = pos * scale + dim_t = torch.arange(num_pos_feats, dtype=torch.float32, device=pos.device) + dim_t = temperature**(2 * (dim_t // 2) / num_pos_feats) + pos_x = pos[..., 0, None] / dim_t + pos_y = pos[..., 1, None] / dim_t + pos_z = pos[..., 2, None] / dim_t + pos_x = torch.stack((pos_x[..., 0::2].sin(), pos_x[..., 1::2].cos()), + dim=-1).flatten(-2) + pos_y = torch.stack((pos_y[..., 0::2].sin(), pos_y[..., 1::2].cos()), + dim=-1).flatten(-2) + pos_z = torch.stack((pos_z[..., 0::2].sin(), pos_z[..., 1::2].cos()), + dim=-1).flatten(-2) + posemb = torch.cat((pos_y, pos_x, pos_z), dim=-1) + return posemb + + +@MODELS.register_module() +class PETRHead(AnchorFreeHead): + """Implements the DETR transformer head. See `paper: End-to-End Object + Detection with Transformers. + + `_ for details. + Args: + num_classes (int): Number of categories excluding the background. + in_channels (int): Number of channels in the input feature map. + num_query (int): Number of query in Transformer. + num_reg_fcs (int, optional): Number of fully-connected layers used in + `FFN`, which is then used for the regression head. Default 2. + transformer (obj:`mmcv.ConfigDict`|dict): Config for transformer. + Default: None. + sync_cls_avg_factor (bool): Whether to sync the avg_factor of + all ranks. Default to False. + positional_encoding (obj:`mmcv.ConfigDict`|dict): + Config for position encoding. + loss_cls (obj:`mmcv.ConfigDict`|dict): Config of the + classification loss. Default `CrossEntropyLoss`. + loss_bbox (obj:`mmcv.ConfigDict`|dict): Config of the + regression loss. Default `L1Loss`. + loss_iou (obj:`mmcv.ConfigDict`|dict): Config of the + regression iou loss. Default `GIoULoss`. + tran_cfg (obj:`mmcv.ConfigDict`|dict): Training config of + transformer head. + test_cfg (obj:`mmcv.ConfigDict`|dict): Testing config of + transformer head. + init_cfg (dict or list[dict], optional): Initialization config dict. 
+ Default: None + """ + _version = 2 + + def __init__(self, + num_classes, + in_channels, + num_query=100, + num_reg_fcs=2, + transformer=None, + sync_cls_avg_factor=False, + positional_encoding=dict( + type='SinePositionalEncoding', + num_feats=128, + normalize=True), + code_weights=None, + bbox_coder=None, + loss_cls=dict( + type='CrossEntropyLoss', + bg_cls_weight=0.1, + use_sigmoid=False, + loss_weight=1.0, + class_weight=1.0), + loss_bbox=dict(type='L1Loss', loss_weight=5.0), + loss_iou=dict(type='GIoULoss', loss_weight=2.0), + train_cfg=dict( + assigner=dict( + type='HungarianAssigner', + cls_cost=dict(type='ClassificationCost', weight=1.), + reg_cost=dict(type='BBoxL1Cost', weight=5.0), + iou_cost=dict( + type='IoUCost', iou_mode='giou', weight=2.0))), + test_cfg=dict(max_per_img=100), + with_position=True, + with_multiview=False, + depth_step=0.8, + depth_num=64, + LID=False, + depth_start=1, + position_range=[-65, -65, -8.0, 65, 65, 8.0], + init_cfg=None, + normedlinear=False, + **kwargs): + # NOTE here use `AnchorFreeHead` instead of `TransformerHead`, + # since it brings inconvenience when the initialization of + # `AnchorFreeHead` is called. + if 'code_size' in kwargs: + self.code_size = kwargs['code_size'] + else: + self.code_size = 10 + if code_weights is not None: + self.code_weights = code_weights + else: + self.code_weights = [ + 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.2, 0.2 + ] + self.code_weights = self.code_weights[:self.code_size] + self.bg_cls_weight = 0 + self.sync_cls_avg_factor = sync_cls_avg_factor + class_weight = loss_cls.get('class_weight', None) + if class_weight is not None and (self.__class__ is PETRHead): + assert isinstance(class_weight, float), 'Expected ' \ + 'class_weight to have type float. Found ' \ + f'{type(class_weight)}.' + # NOTE following the official DETR rep0, bg_cls_weight means + # relative classification weight of the no-object class. + bg_cls_weight = loss_cls.get('bg_cls_weight', class_weight) + assert isinstance(bg_cls_weight, float), 'Expected ' \ + 'bg_cls_weight to have type float. Found ' \ + f'{type(bg_cls_weight)}.' + class_weight = torch.ones(num_classes + 1) * class_weight + # set background class as the last indice + class_weight[num_classes] = bg_cls_weight + loss_cls.update({'class_weight': class_weight}) + if 'bg_cls_weight' in loss_cls: + loss_cls.pop('bg_cls_weight') + self.bg_cls_weight = bg_cls_weight + + if train_cfg: + assert 'assigner' in train_cfg, 'assigner should be provided '\ + 'when train_cfg is set.' + assigner = train_cfg['assigner'] + assert loss_cls['loss_weight'] == assigner['cls_cost']['weight'], \ + 'The classification weight for loss and matcher should be' \ + 'exactly the same.' + assert loss_bbox['loss_weight'] == assigner['reg_cost'][ + 'weight'], 'The regression L1 weight for loss and matcher ' \ + 'should be exactly the same.' + # assert loss_iou['loss_weight'] == assigner['iou_cost'][ + # 'weight'], \ + # 'The regression iou weight for loss and matcher should be' \ + # 'exactly the same.' 
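+            # PETR keeps DETR's set-prediction training: the Hungarian-style
+            # assigner below yields a one-to-one matching between queries
+            # and ground-truth boxes, so no real positive/negative sampling
+            # is needed and a PseudoSampler merely wraps the assignment.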
+ self.assigner = TASK_UTILS.build(assigner) + # DETR sampling=False, so use PseudoSampler + sampler_cfg = dict(type='PseudoSampler') + self.sampler = TASK_UTILS.build(sampler_cfg) + + self.num_query = num_query + self.num_classes = num_classes + self.in_channels = in_channels + self.num_reg_fcs = num_reg_fcs + self.train_cfg = train_cfg + self.test_cfg = test_cfg + self.fp16_enabled = False + self.embed_dims = 256 + self.depth_step = depth_step + self.depth_num = depth_num + self.position_dim = 3 * self.depth_num + self.position_range = position_range + self.LID = LID + self.depth_start = depth_start + self.position_level = 0 + self.with_position = with_position + self.with_multiview = with_multiview + assert 'num_feats' in positional_encoding + num_feats = positional_encoding['num_feats'] + assert num_feats * 2 == self.embed_dims, 'embed_dims should' \ + f' be exactly 2 times of num_feats. Found {self.embed_dims}' \ + f' and {num_feats}.' + self.act_cfg = transformer.get('act_cfg', + dict(type='ReLU', inplace=True)) + self.num_pred = 6 + self.normedlinear = normedlinear + super(PETRHead, self).__init__( + num_classes=num_classes, + in_channels=in_channels, + loss_cls=loss_cls, + loss_bbox=loss_bbox, + bbox_coder=bbox_coder, + init_cfg=init_cfg) + + self.loss_cls = MODELS.build(loss_cls) + self.loss_bbox = MODELS.build(loss_bbox) + self.loss_iou = MODELS.build(loss_iou) + + if self.loss_cls.use_sigmoid: + self.cls_out_channels = num_classes + else: + self.cls_out_channels = num_classes + 1 + # self.activate = build_activation_layer(self.act_cfg) + # if self.with_multiview or not self.with_position: + # self.positional_encoding = build_positional_encoding( + # positional_encoding) + self.positional_encoding = TASK_UTILS.build(positional_encoding) + self.transformer = MODELS.build(transformer) + self.code_weights = nn.Parameter( + torch.tensor(self.code_weights, requires_grad=False), + requires_grad=False) + self.bbox_coder = TASK_UTILS.build(bbox_coder) + self.pc_range = self.bbox_coder.pc_range + self._init_layers() + + def _init_layers(self): + """Initialize layers of the transformer head.""" + if self.with_position: + self.input_proj = Conv2d( + self.in_channels, self.embed_dims, kernel_size=1) + else: + self.input_proj = Conv2d( + self.in_channels, self.embed_dims, kernel_size=1) + + cls_branch = [] + for _ in range(self.num_reg_fcs): + cls_branch.append(Linear(self.embed_dims, self.embed_dims)) + cls_branch.append(nn.LayerNorm(self.embed_dims)) + cls_branch.append(nn.ReLU(inplace=True)) + if self.normedlinear: + cls_branch.append( + NormedLinear(self.embed_dims, self.cls_out_channels)) + else: + cls_branch.append(Linear(self.embed_dims, self.cls_out_channels)) + fc_cls = nn.Sequential(*cls_branch) + + reg_branch = [] + for _ in range(self.num_reg_fcs): + reg_branch.append(Linear(self.embed_dims, self.embed_dims)) + reg_branch.append(nn.ReLU()) + reg_branch.append(Linear(self.embed_dims, self.code_size)) + reg_branch = nn.Sequential(*reg_branch) + + self.cls_branches = nn.ModuleList( + [fc_cls for _ in range(self.num_pred)]) + self.reg_branches = nn.ModuleList( + [reg_branch for _ in range(self.num_pred)]) + + if self.with_multiview: + self.adapt_pos3d = nn.Sequential( + nn.Conv2d( + self.embed_dims * 3 // 2, + self.embed_dims * 4, + kernel_size=1, + stride=1, + padding=0), + nn.ReLU(), + nn.Conv2d( + self.embed_dims * 4, + self.embed_dims, + kernel_size=1, + stride=1, + padding=0), + ) + else: + self.adapt_pos3d = nn.Sequential( + nn.Conv2d( + self.embed_dims, + self.embed_dims, + 
kernel_size=1, + stride=1, + padding=0), + nn.ReLU(), + nn.Conv2d( + self.embed_dims, + self.embed_dims, + kernel_size=1, + stride=1, + padding=0), + ) + + if self.with_position: + self.position_encoder = nn.Sequential( + nn.Conv2d( + self.position_dim, + self.embed_dims * 4, + kernel_size=1, + stride=1, + padding=0), + nn.ReLU(), + nn.Conv2d( + self.embed_dims * 4, + self.embed_dims, + kernel_size=1, + stride=1, + padding=0), + ) + + self.reference_points = nn.Embedding(self.num_query, 3) + self.query_embedding = nn.Sequential( + nn.Linear(self.embed_dims * 3 // 2, self.embed_dims), + nn.ReLU(), + nn.Linear(self.embed_dims, self.embed_dims), + ) + + def init_weights(self): + """Initialize weights of the transformer head.""" + # The initialization for transformer is important + self.transformer.init_weights() + nn.init.uniform_(self.reference_points.weight.data, 0, 1) + if self.loss_cls.use_sigmoid: + bias_init = bias_init_with_prob(0.01) + for m in self.cls_branches: + nn.init.constant_(m[-1].bias, bias_init) + + def position_embeding(self, img_feats, img_metas, masks=None): + eps = 1e-5 + pad_h, pad_w = img_metas[0]['pad_shape'] + B, N, C, H, W = img_feats[self.position_level].shape + coords_h = torch.arange( + H, device=img_feats[0].device).float() * pad_h / H + coords_w = torch.arange( + W, device=img_feats[0].device).float() * pad_w / W + + if self.LID: + index = torch.arange( + start=0, + end=self.depth_num, + step=1, + device=img_feats[0].device).float() + index_1 = index + 1 + bin_size = (self.position_range[3] - self.depth_start) / ( + self.depth_num * (1 + self.depth_num)) + coords_d = self.depth_start + bin_size * index * index_1 + else: + index = torch.arange( + start=0, + end=self.depth_num, + step=1, + device=img_feats[0].device).float() + bin_size = (self.position_range[3] - + self.depth_start) / self.depth_num + coords_d = self.depth_start + bin_size * index + + D = coords_d.shape[0] + coords = torch.stack(torch.meshgrid([coords_w, coords_h, coords_d + ])).permute(1, 2, 3, + 0) # W, H, D, 3 + coords = torch.cat((coords, torch.ones_like(coords[..., :1])), -1) + coords[..., :2] = coords[..., :2] * torch.maximum( + coords[..., 2:3], + torch.ones_like(coords[..., 2:3]) * eps) + + img2lidars = [] + for img_meta in img_metas: + img2lidar = [] + for i in range(len(img_meta['lidar2img'])): + img2lidar.append(np.linalg.inv(img_meta['lidar2img'][i])) + img2lidars.append(np.asarray(img2lidar)) + img2lidars = np.asarray(img2lidars) + img2lidars = coords.new_tensor(img2lidars) # (B, N, 4, 4) + + coords = coords.view(1, 1, W, H, D, 4, 1).repeat(B, N, 1, 1, 1, 1, 1) + img2lidars = img2lidars.view(B, N, 1, 1, 1, 4, + 4).repeat(1, 1, W, H, D, 1, 1) + coords3d = torch.matmul(img2lidars, coords).squeeze(-1)[..., :3] + coords3d[..., 0:1] = (coords3d[..., 0:1] - self.position_range[0]) / ( + self.position_range[3] - self.position_range[0]) + coords3d[..., 1:2] = (coords3d[..., 1:2] - self.position_range[1]) / ( + self.position_range[4] - self.position_range[1]) + coords3d[..., 2:3] = (coords3d[..., 2:3] - self.position_range[2]) / ( + self.position_range[5] - self.position_range[2]) + + coords_mask = (coords3d > 1.0) | (coords3d < 0.0) + coords_mask = coords_mask.flatten(-2).sum(-1) > (D * 0.5) + coords_mask = masks | coords_mask.permute(0, 1, 3, 2) + coords3d = coords3d.permute(0, 1, 4, 5, 3, + 2).contiguous().view(B * N, -1, H, W) + coords3d = inverse_sigmoid(coords3d) + coords_position_embeding = self.position_encoder(coords3d) + + return coords_position_embeding.view(B, N, self.embed_dims, 
H, + W), coords_mask + + def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, + missing_keys, unexpected_keys, error_msgs): + """load checkpoints.""" + # NOTE here use `AnchorFreeHead` instead of `TransformerHead`, + # since `AnchorFreeHead._load_from_state_dict` should not be + # called here. Invoking the default `Module._load_from_state_dict` + # is enough. + + # Names of some parameters in has been changed. + version = local_metadata.get('version', None) + if (version is None or version < 2) and self.__class__ is PETRHead: + convert_dict = { + '.self_attn.': '.attentions.0.', + # '.ffn.': '.ffns.0.', + '.multihead_attn.': '.attentions.1.', + '.decoder.norm.': '.decoder.post_norm.' + } + state_dict_keys = list(state_dict.keys()) + for k in state_dict_keys: + for ori_key, convert_key in convert_dict.items(): + if ori_key in k: + convert_key = k.replace(ori_key, convert_key) + state_dict[convert_key] = state_dict[k] + del state_dict[k] + + super(AnchorFreeHead, + self)._load_from_state_dict(state_dict, prefix, local_metadata, + strict, missing_keys, + unexpected_keys, error_msgs) + + def forward(self, mlvl_feats, img_metas): + """Forward function. + + Args: + mlvl_feats (tuple[Tensor]): Features from the upstream + network, each is a 5D-tensor with shape + (B, N, C, H, W). + Returns: + all_cls_scores (Tensor): Outputs from the classification head, \ + shape [nb_dec, bs, num_query, cls_out_channels]. Note \ + cls_out_channels should includes background. + all_bbox_preds (Tensor): Sigmoid outputs from the regression \ + head with normalized coordinate format \ + (cx, cy, w, l, cz, h, theta, vx, vy). \ + Shape [nb_dec, bs, num_query, 9]. + """ + + x = mlvl_feats[0] + batch_size, num_cams = x.size(0), x.size(1) + input_img_h, input_img_w = img_metas[0]['pad_shape'] + masks = x.new_ones((batch_size, num_cams, input_img_h, input_img_w)) + for img_id in range(batch_size): + for cam_id in range(num_cams): + img_h, img_w, _ = img_metas[img_id]['img_shape'][cam_id] + masks[img_id, cam_id, :img_h, :img_w] = 0 + x = self.input_proj(x.flatten(0, 1)) + x = x.view(batch_size, num_cams, *x.shape[-3:]) + # interpolate masks to have the same spatial shape with x + masks = F.interpolate(masks, size=x.shape[-2:]).to(torch.bool) + + if self.with_position: + coords_position_embeding, _ = self.position_embeding( + mlvl_feats, img_metas, masks) + pos_embed = coords_position_embeding + if self.with_multiview: + sin_embed = self.positional_encoding(masks) + sin_embed = self.adapt_pos3d(sin_embed.flatten(0, 1)).view( + x.size()) + pos_embed = pos_embed + sin_embed + else: + pos_embeds = [] + for i in range(num_cams): + xy_embed = self.positional_encoding(masks[:, i, :, :]) + pos_embeds.append(xy_embed.unsqueeze(1)) + sin_embed = torch.cat(pos_embeds, 1) + sin_embed = self.adapt_pos3d(sin_embed.flatten(0, 1)).view( + x.size()) + pos_embed = pos_embed + sin_embed + else: + if self.with_multiview: + pos_embed = self.positional_encoding(masks) + pos_embed = self.adapt_pos3d(pos_embed.flatten(0, 1)).view( + x.size()) + else: + pos_embeds = [] + for i in range(num_cams): + pos_embed = self.positional_encoding(masks[:, i, :, :]) + pos_embeds.append(pos_embed.unsqueeze(1)) + pos_embed = torch.cat(pos_embeds, 1) + + reference_points = self.reference_points.weight + query_embeds = self.query_embedding(pos2posemb3d(reference_points)) + reference_points = reference_points.unsqueeze(0).repeat( + batch_size, 1, 1) # .sigmoid() + + outs_dec, _ = self.transformer(x, masks, query_embeds, pos_embed, + 
self.reg_branches) + outs_dec = torch.nan_to_num(outs_dec) + outputs_classes = [] + outputs_coords = [] + for lvl in range(outs_dec.shape[0]): + reference = inverse_sigmoid(reference_points.clone()) + assert reference.shape[-1] == 3 + outputs_class = self.cls_branches[lvl](outs_dec[lvl]).to( + torch.float32) + tmp = self.reg_branches[lvl](outs_dec[lvl]).to(torch.float32) + + tmp[..., 0:2] += reference[..., 0:2] + tmp[..., 0:2] = tmp[..., 0:2].sigmoid() + tmp[..., 4:5] += reference[..., 2:3] + tmp[..., 4:5] = tmp[..., 4:5].sigmoid() + + outputs_coord = tmp + outputs_classes.append(outputs_class) + outputs_coords.append(outputs_coord) + + all_cls_scores = torch.stack(outputs_classes) + all_bbox_preds = torch.stack(outputs_coords) + + all_bbox_preds[..., 0:1] = ( + all_bbox_preds[..., 0:1] * (self.pc_range[3] - self.pc_range[0]) + + self.pc_range[0]) + all_bbox_preds[..., 1:2] = ( + all_bbox_preds[..., 1:2] * (self.pc_range[4] - self.pc_range[1]) + + self.pc_range[1]) + all_bbox_preds[..., 4:5] = ( + all_bbox_preds[..., 4:5] * (self.pc_range[5] - self.pc_range[2]) + + self.pc_range[2]) + + outs = { + 'all_cls_scores': all_cls_scores, + 'all_bbox_preds': all_bbox_preds, + 'enc_cls_scores': None, + 'enc_bbox_preds': None, + } + return outs + + def _get_target_single(self, + cls_score, + bbox_pred, + gt_labels, + gt_bboxes, + gt_bboxes_ignore=None): + """"Compute regression and classification targets for one image. + Outputs from a single decoder layer of a single feature level are used. + Args: + cls_score (Tensor): Box score logits from a single decoder layer + for one image. Shape [num_query, cls_out_channels]. + bbox_pred (Tensor): Sigmoid outputs from a single decoder layer + for one image, with normalized coordinate (cx, cy, w, h) and + shape [num_query, 4]. + gt_bboxes (Tensor): Ground truth bboxes for one image with + shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format. + gt_labels (Tensor): Ground truth class indices for one image + with shape (num_gts, ). + gt_bboxes_ignore (Tensor, optional): Bounding boxes + which can be ignored. Default None. + Returns: + tuple[Tensor]: a tuple containing the following for one image. + - labels (Tensor): Labels of each image. + - label_weights (Tensor]): Label weights of each image. + - bbox_targets (Tensor): BBox targets of each image. + - bbox_weights (Tensor): BBox weights of each image. + - pos_inds (Tensor): Sampled positive indices for each image. + - neg_inds (Tensor): Sampled negative indices for each image. 
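+
+        Note:
+            The wording above is inherited from the 2D DETR head; in PETR
+            `bbox_pred` has shape [num_query, code_size] (10 by default)
+            and `gt_bboxes` is a (num_gts, 9) tensor of gravity center,
+            box size, yaw and 2D velocity.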
+ """ + + num_bboxes = bbox_pred.size(0) + # assigner and sampler + assign_result = self.assigner.assign(bbox_pred, cls_score, gt_bboxes, + gt_labels, gt_bboxes_ignore) + pred_instance_3d = InstanceData(priors=bbox_pred) + gt_instances_3d = InstanceData(bboxes_3d=gt_bboxes) + sampling_result = self.sampler.sample(assign_result, pred_instance_3d, + gt_instances_3d) + pos_inds = sampling_result.pos_inds + neg_inds = sampling_result.neg_inds + + # label targets + labels = gt_bboxes.new_full((num_bboxes, ), + self.num_classes, + dtype=torch.long) + labels[pos_inds] = gt_labels[sampling_result.pos_assigned_gt_inds] + label_weights = gt_bboxes.new_ones(num_bboxes) + + # bbox targets + code_size = gt_bboxes.size(1) + bbox_targets = torch.zeros_like(bbox_pred)[..., :code_size] + bbox_weights = torch.zeros_like(bbox_pred) + bbox_weights[pos_inds] = 1.0 + # DETR + bbox_targets[pos_inds] = sampling_result.pos_gt_bboxes + return (labels, label_weights, bbox_targets, bbox_weights, pos_inds, + neg_inds) + + def get_targets(self, + cls_scores_list, + bbox_preds_list, + gt_bboxes_list, + gt_labels_list, + gt_bboxes_ignore_list=None): + """"Compute regression and classification targets for a batch image. + Outputs from a single decoder layer of a single feature level are used. + Args: + cls_scores_list (list[Tensor]): Box score logits from a single + decoder layer for each image with shape [num_query, + cls_out_channels]. + bbox_preds_list (list[Tensor]): Sigmoid outputs from a single + decoder layer for each image, with normalized coordinate + (cx, cy, w, h) and shape [num_query, 4]. + gt_bboxes_list (list[Tensor]): Ground truth bboxes for each image + with shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format. + gt_labels_list (list[Tensor]): Ground truth class indices for each + image with shape (num_gts, ). + gt_bboxes_ignore_list (list[Tensor], optional): Bounding + boxes which can be ignored for each image. Default None. + Returns: + tuple: a tuple containing the following targets. + - labels_list (list[Tensor]): Labels for all images. + - label_weights_list (list[Tensor]): Label weights for all \ + images. + - bbox_targets_list (list[Tensor]): BBox targets for all \ + images. + - bbox_weights_list (list[Tensor]): BBox weights for all \ + images. + - num_total_pos (int): Number of positive samples in all \ + images. + - num_total_neg (int): Number of negative samples in all \ + images. + """ + assert gt_bboxes_ignore_list is None, \ + 'Only supports for gt_bboxes_ignore setting to None.' + num_imgs = len(cls_scores_list) + gt_bboxes_ignore_list = [ + gt_bboxes_ignore_list for _ in range(num_imgs) + ] + gt_labels_list = gt_labels_list[0] + (labels_list, label_weights_list, bbox_targets_list, bbox_weights_list, + pos_inds_list, + neg_inds_list) = multi_apply(self._get_target_single, cls_scores_list, + bbox_preds_list, gt_labels_list, + gt_bboxes_list, gt_bboxes_ignore_list) + num_total_pos = sum((inds.numel() for inds in pos_inds_list)) + num_total_neg = sum((inds.numel() for inds in neg_inds_list)) + return (labels_list, label_weights_list, bbox_targets_list, + bbox_weights_list, num_total_pos, num_total_neg) + + def loss_by_feat_single(self, + cls_scores, + bbox_preds, + gt_bboxes_list, + gt_labels_list, + gt_bboxes_ignore_list=None): + """"Loss function for outputs from a single decoder layer of a single + feature level. + + Args: + cls_scores (Tensor): Box score logits from a single decoder layer + for all images. Shape [bs, num_query, cls_out_channels]. 
+            bbox_preds (Tensor): Sigmoid outputs from a single decoder layer
+                for all images, with normalized coordinate (cx, cy, w, h) and
+                shape [bs, num_query, 4].
+            gt_bboxes_list (list[Tensor]): Ground truth bboxes for each image
+                with shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
+            gt_labels_list (list[Tensor]): Ground truth class indices for each
+                image with shape (num_gts, ).
+            gt_bboxes_ignore_list (list[Tensor], optional): Bounding
+                boxes which can be ignored for each image. Default None.
+        Returns:
+            dict[str, Tensor]: A dictionary of loss components for outputs
+                from a single decoder layer.
+        """
+        num_imgs = cls_scores.size(0)
+        cls_scores_list = [cls_scores[i] for i in range(num_imgs)]
+        bbox_preds_list = [bbox_preds[i] for i in range(num_imgs)]
+        cls_reg_targets = self.get_targets(cls_scores_list, bbox_preds_list,
+                                           gt_bboxes_list, gt_labels_list,
+                                           gt_bboxes_ignore_list)
+        (labels_list, label_weights_list, bbox_targets_list,
+         bbox_weights_list, num_total_pos, num_total_neg) = cls_reg_targets
+        labels = torch.cat(labels_list, 0)
+        label_weights = torch.cat(label_weights_list, 0)
+        bbox_targets = torch.cat(bbox_targets_list, 0)
+
+        bbox_weights = torch.cat(bbox_weights_list, 0)
+
+        # classification loss
+        cls_scores = cls_scores.reshape(-1, self.cls_out_channels)
+        # construct weighted avg_factor to match with the official DETR repo
+        cls_avg_factor = num_total_pos * 1.0 + \
+            num_total_neg * self.bg_cls_weight
+        # if self.sync_cls_avg_factor:
+        #     cls_avg_factor = reduce_mean(
+        #         cls_scores.new_tensor([cls_avg_factor]))
+
+        cls_avg_factor = max(cls_avg_factor, 1)
+        loss_cls = self.loss_cls(
+            cls_scores, labels, label_weights, avg_factor=cls_avg_factor)
+
+        # Compute the average number of gt boxes across all gpus, for
+        # normalization purposes
+        num_total_pos = loss_cls.new_tensor([num_total_pos])
+        # num_total_pos = torch.clamp(reduce_mean(num_total_pos), min=1).item()
+        num_total_pos = torch.clamp(num_total_pos, min=1).item()
+
+        # regression L1 loss
+        bbox_preds = bbox_preds.reshape(-1, bbox_preds.size(-1))
+        normalized_bbox_targets = normalize_bbox(bbox_targets, self.pc_range)
+        isnotnan = torch.isfinite(normalized_bbox_targets).all(dim=-1)
+        bbox_weights = bbox_weights * self.code_weights
+
+        loss_bbox = self.loss_bbox(
+            bbox_preds[isnotnan, :10],
+            normalized_bbox_targets[isnotnan, :10],
+            bbox_weights[isnotnan, :10],
+            avg_factor=num_total_pos)
+
+        loss_cls = torch.nan_to_num(loss_cls)
+        loss_bbox = torch.nan_to_num(loss_bbox)
+        return loss_cls, loss_bbox
+
+    def loss_by_feat(self,
+                     gt_bboxes_list,
+                     gt_labels_list,
+                     preds_dicts,
+                     gt_bboxes_ignore=None):
+        """Loss function.
+        Args:
+            gt_bboxes_list (list[Tensor]): Ground truth bboxes for each image
+                with shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
+            gt_labels_list (list[Tensor]): Ground truth class indices for each
+                image with shape (num_gts, ).
+            preds_dicts (dict): Outputs of :meth:`forward`, containing:
+                all_cls_scores (Tensor): Classification score of all
+                    decoder layers, has shape
+                    [nb_dec, bs, num_query, cls_out_channels].
+                all_bbox_preds (Tensor): Sigmoid regression
+                    outputs of all decode layers. Each is a 4D-tensor with
+                    normalized coordinate format (cx, cy, w, h) and shape
+                    [nb_dec, bs, num_query, 4].
+                enc_cls_scores (Tensor): Classification scores of
+                    points on encode feature map, has shape
+                    (N, h*w, num_classes). Only passed when as_two_stage is
+                    True, otherwise it is None.
+                enc_bbox_preds (Tensor): Regression results of each point
+                    on the encode feature map, has shape (N, h*w, 4). Only
+                    passed when as_two_stage is True, otherwise it is None.
+            gt_bboxes_ignore (list[Tensor], optional): Bounding boxes
+                which can be ignored for each image. Default None.
+        Returns:
+            dict[str, Tensor]: A dictionary of loss components.
+        """
+        assert gt_bboxes_ignore is None, \
+            f'{self.__class__.__name__} only supports ' \
+            f'gt_bboxes_ignore being set to None.'
+
+        all_cls_scores = preds_dicts['all_cls_scores']
+        all_bbox_preds = preds_dicts['all_bbox_preds']
+        enc_cls_scores = preds_dicts['enc_cls_scores']
+        enc_bbox_preds = preds_dicts['enc_bbox_preds']
+
+        num_dec_layers = len(all_cls_scores)
+        device = gt_labels_list[0].device
+
+        gt_bboxes_list = [
+            torch.cat((gt_bboxes.gravity_center, gt_bboxes.tensor[:, 3:]),
+                      dim=1).to(device) for gt_bboxes in gt_bboxes_list
+        ]
+
+        all_gt_bboxes_list = [gt_bboxes_list for _ in range(num_dec_layers)]
+        all_gt_labels_list = [[gt_labels_list] for _ in range(num_dec_layers)]
+        all_gt_bboxes_ignore_list = [
+            gt_bboxes_ignore for _ in range(num_dec_layers)
+        ]
+
+        losses_cls, losses_bbox = multi_apply(self.loss_by_feat_single,
+                                              all_cls_scores, all_bbox_preds,
+                                              all_gt_bboxes_list,
+                                              all_gt_labels_list,
+                                              all_gt_bboxes_ignore_list)
+
+        loss_dict = dict()
+        # loss of proposal generated from encode feature map.
+        if enc_cls_scores is not None:
+            binary_labels_list = [
+                torch.zeros_like(gt_labels_list[i])
+                for i in range(len(all_gt_labels_list))
+            ]
+            enc_loss_cls, enc_losses_bbox = \
+                self.loss_by_feat_single(enc_cls_scores, enc_bbox_preds,
+                                         gt_bboxes_list, binary_labels_list,
+                                         gt_bboxes_ignore)
+            loss_dict['enc_loss_cls'] = enc_loss_cls
+            loss_dict['enc_loss_bbox'] = enc_losses_bbox
+
+        # loss from the last decoder layer
+        loss_dict['loss_cls'] = losses_cls[-1]
+        loss_dict['loss_bbox'] = losses_bbox[-1]
+
+        # loss from other decoder layers
+        num_dec_layer = 0
+        for loss_cls_i, loss_bbox_i in zip(losses_cls[:-1], losses_bbox[:-1]):
+            loss_dict[f'd{num_dec_layer}.loss_cls'] = loss_cls_i
+            loss_dict[f'd{num_dec_layer}.loss_bbox'] = loss_bbox_i
+            num_dec_layer += 1
+        return loss_dict
+
+    def get_bboxes(self, preds_dicts, img_metas, rescale=False):
+        """Generate bboxes from bbox head predictions.
+
+        Args:
+            preds_dicts (tuple[list[dict]]): Prediction results.
+            img_metas (list[dict]): Point cloud and image's meta info.
+        Returns:
+            list[dict]: Decoded bbox, scores and labels.
+        """
+        preds_dicts = self.bbox_coder.decode(preds_dicts)
+        num_samples = len(preds_dicts)
+
+        ret_list = []
+        for i in range(num_samples):
+            preds = preds_dicts[i]
+            bboxes = preds['bboxes']
+            bboxes[:, 2] = bboxes[:, 2] - bboxes[:, 5] * 0.5
+            bboxes = img_metas[i]['box_type_3d'](bboxes, bboxes.size(-1))
+            scores = preds['scores']
+            labels = preds['labels']
+            ret_list.append([bboxes, scores, labels])
+        return ret_list
diff --git a/projects/PETR/petr/petr_transformer.py b/projects/PETR/petr/petr_transformer.py
new file mode 100644
index 0000000000..dbb4cc332b
--- /dev/null
+++ b/projects/PETR/petr/petr_transformer.py
@@ -0,0 +1,540 @@
+# ------------------------------------------------------------------------
+# Copyright (c) 2022 megvii-model. All Rights Reserved.
+# ------------------------------------------------------------------------
+# Modified from DETR3D (https://github.com/WangYueFt/detr3d)
+# Copyright (c) 2021 Wang, Yue
+# ------------------------------------------------------------------------
+# Modified from mmdetection3d (https://github.com/open-mmlab/mmdetection3d)
+# Copyright (c) OpenMMLab. All rights reserved.
+# ------------------------------------------------------------------------ + +import warnings + +import torch +import torch.nn as nn +import torch.utils.checkpoint as cp +from mmcv.cnn import build_norm_layer +from mmcv.cnn.bricks.transformer import (BaseTransformerLayer, + TransformerLayerSequence) +from mmengine.model import BaseModule +from mmengine.model.weight_init import xavier_init + +# from mmcv.utils import deprecated_api_warning +from mmdet3d.registry import MODELS, TASK_UTILS + + +@MODELS.register_module() +class PETRTransformer(BaseModule): + """Implements the DETR transformer. Following the official DETR + implementation, this module copy-paste from torch.nn.Transformer with + modifications: + + * positional encodings are passed in MultiheadAttention + * extra LN at the end of encoder is removed + * decoder returns a stack of activations from all decoding layers + See `paper: End-to-End Object Detection with Transformers + `_ for details. + Args: + encoder (`mmcv.ConfigDict` | Dict): Config of + TransformerEncoder. Defaults to None. + decoder ((`mmcv.ConfigDict` | Dict)): Config of + TransformerDecoder. Defaults to None + init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization. + Defaults to None. + """ + + def __init__(self, encoder=None, decoder=None, init_cfg=None, cross=False): + super(PETRTransformer, self).__init__(init_cfg=init_cfg) + if encoder is not None: + self.encoder = MODELS.build(encoder) + else: + self.encoder = None + self.decoder = MODELS.build(decoder) + self.embed_dims = self.decoder.embed_dims + self.cross = cross + + def init_weights(self): + # follow the official DETR to init parameters + for m in self.modules(): + if hasattr(m, 'weight') and m.weight.dim() > 1: + xavier_init(m, distribution='uniform') + self._is_init = True + + def forward(self, x, mask, query_embed, pos_embed, reg_branch=None): + """Forward function for `Transformer`. + Args: + x (Tensor): Input query with shape [bs, c, h, w] where + c = embed_dims. + mask (Tensor): The key_padding_mask used for encoder and decoder, + with shape [bs, h, w]. + query_embed (Tensor): The query embedding for decoder, with shape + [num_query, c]. + pos_embed (Tensor): The positional encoding for encoder and + decoder, with the same shape as `x`. + Returns: + tuple[Tensor]: results of decoder containing the following tensor. + - out_dec: Output from decoder. If return_intermediate_dec \ + is True output has shape [num_dec_layers, bs, + num_query, embed_dims], else has shape [1, bs, \ + num_query, embed_dims]. + - memory: Output results from encoder, with shape \ + [bs, embed_dims, h, w]. + """ + bs, n, c, h, w = x.shape + memory = x.permute(1, 3, 4, 0, + 2).reshape(-1, bs, + c) # [bs, n, c, h, w] -> [n*h*w, bs, c] + pos_embed = pos_embed.permute(1, 3, 4, 0, 2).reshape( + -1, bs, c) # [bs, n, c, h, w] -> [n*h*w, bs, c] + query_embed = query_embed.unsqueeze(1).repeat( + 1, bs, 1) # [num_query, dim] -> [num_query, bs, dim] + mask = mask.view(bs, -1) # [bs, n, h, w] -> [bs, n*h*w] + target = torch.zeros_like(query_embed) + + # out_dec: [num_layers, num_query, bs, dim] + out_dec = self.decoder( + query=target, + key=memory, + value=memory, + key_pos=pos_embed, + query_pos=query_embed, + key_padding_mask=mask, + reg_branch=reg_branch, + ) + out_dec = out_dec.transpose(1, 2) + memory = memory.reshape(n, h, w, bs, c).permute(3, 0, 4, 1, 2) + return out_dec, memory + + +@MODELS.register_module() +class PETRDNTransformer(BaseModule): + """Implements the DETR transformer. 
Following the official DETR + implementation, this module copy-paste from torch.nn.Transformer with + modifications: + + * positional encodings are passed in MultiheadAttention + * extra LN at the end of encoder is removed + * decoder returns a stack of activations from all decoding layers + See `paper: End-to-End Object Detection with Transformers + `_ for details. + Args: + encoder (`mmcv.ConfigDict` | Dict): Config of + TransformerEncoder. Defaults to None. + decoder ((`mmcv.ConfigDict` | Dict)): Config of + TransformerDecoder. Defaults to None + init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization. + Defaults to None. + """ + + def __init__(self, encoder=None, decoder=None, init_cfg=None, cross=False): + super(PETRDNTransformer, self).__init__(init_cfg=init_cfg) + if encoder is not None: + self.encoder = MODELS.build(encoder) + else: + self.encoder = None + self.decoder = MODELS.build(decoder) + self.embed_dims = self.decoder.embed_dims + self.cross = cross + + def init_weights(self): + # follow the official DETR to init parameters + for m in self.modules(): + if hasattr(m, 'weight') and m.weight.dim() > 1: + xavier_init(m, distribution='uniform') + self._is_init = True + + def forward(self, + x, + mask, + query_embed, + pos_embed, + attn_masks=None, + reg_branch=None): + """Forward function for `Transformer`. + Args: + x (Tensor): Input query with shape [bs, c, h, w] where + c = embed_dims. + mask (Tensor): The key_padding_mask used for encoder and decoder, + with shape [bs, h, w]. + query_embed (Tensor): The query embedding for decoder, with shape + [num_query, c]. + pos_embed (Tensor): The positional encoding for encoder and + decoder, with the same shape as `x`. + Returns: + tuple[Tensor]: results of decoder containing the following tensor. + - out_dec: Output from decoder. If return_intermediate_dec \ + is True output has shape [num_dec_layers, bs, + num_query, embed_dims], else has shape [1, bs, \ + num_query, embed_dims]. + - memory: Output results from encoder, with shape \ + [bs, embed_dims, h, w]. + """ + bs, n, c, h, w = x.shape + memory = x.permute(1, 3, 4, 0, + 2).reshape(-1, bs, + c) # [bs, n, c, h, w] -> [n*h*w, bs, c] + pos_embed = pos_embed.permute(1, 3, 4, 0, 2).reshape( + -1, bs, c) # [bs, n, c, h, w] -> [n*h*w, bs, c] + query_embed = query_embed.transpose( + 0, 1) # [num_query, dim] -> [num_query, bs, dim] + mask = mask.view(bs, -1) # [bs, n, h, w] -> [bs, n*h*w] + target = torch.zeros_like(query_embed) + # out_dec: [num_layers, num_query, bs, dim] + out_dec = self.decoder( + query=target, + key=memory, + value=memory, + key_pos=pos_embed, + query_pos=query_embed, + key_padding_mask=mask, + attn_masks=[attn_masks, None], + reg_branch=reg_branch, + ) + out_dec = out_dec.transpose(1, 2) + memory = memory.reshape(n, h, w, bs, c).permute(3, 0, 4, 1, 2) + return out_dec, memory + + +@MODELS.register_module() +class PETRTransformerDecoderLayer(BaseTransformerLayer): + """Implements decoder layer in DETR transformer. + + Args: + attn_cfgs (list[`mmcv.ConfigDict`] | list[dict] | dict )): + Configs for self_attention or cross_attention, the order + should be consistent with it in `operation_order`. If it is + a dict, it would be expand to the number of attention in + `operation_order`. + feedforward_channels (int): The hidden dimension for FFNs. + ffn_dropout (float): Probability of an element to be zeroed + in ffn. Default 0.0. + operation_order (tuple[str]): The execution order of operation + in transformer. Such as ('self_attn', 'norm', 'ffn', 'norm'). 
+ Default:None + act_cfg (dict): The activation config for FFNs. Default: `LN` + norm_cfg (dict): Config dict for normalization layer. + Default: `LN`. + ffn_num_fcs (int): The number of fully-connected layers in FFNs. + Default:2. + """ + + def __init__(self, + attn_cfgs, + feedforward_channels, + ffn_dropout=0.0, + operation_order=None, + act_cfg=dict(type='ReLU', inplace=True), + norm_cfg=dict(type='LN'), + ffn_num_fcs=2, + with_cp=True, + **kwargs): + super(PETRTransformerDecoderLayer, self).__init__( + attn_cfgs=attn_cfgs, + feedforward_channels=feedforward_channels, + ffn_dropout=ffn_dropout, + operation_order=operation_order, + act_cfg=act_cfg, + norm_cfg=norm_cfg, + ffn_num_fcs=ffn_num_fcs, + **kwargs) + assert len(operation_order) == 6 + assert set(operation_order) == set( + ['self_attn', 'norm', 'cross_attn', 'ffn']) + self.use_checkpoint = with_cp + + def _forward( + self, + query, + key=None, + value=None, + query_pos=None, + key_pos=None, + attn_masks=None, + query_key_padding_mask=None, + key_padding_mask=None, + ): + """Forward function for `TransformerCoder`. + + Returns: + Tensor: forwarded results with shape [num_query, bs, embed_dims]. + """ + x = super(PETRTransformerDecoderLayer, self).forward( + query, + key=key, + value=value, + query_pos=query_pos, + key_pos=key_pos, + attn_masks=attn_masks, + query_key_padding_mask=query_key_padding_mask, + key_padding_mask=key_padding_mask, + ) + + return x + + def forward(self, + query, + key=None, + value=None, + query_pos=None, + key_pos=None, + attn_masks=None, + query_key_padding_mask=None, + key_padding_mask=None, + **kwargs): + """Forward function for `TransformerCoder`. + + Returns: + Tensor: forwarded results with shape [num_query, bs, embed_dims]. + """ + + if self.use_checkpoint and self.training: + x = cp.checkpoint( + self._forward, + query, + key, + value, + query_pos, + key_pos, + attn_masks, + query_key_padding_mask, + key_padding_mask, + ) + else: + x = self._forward( + query, + key=key, + value=value, + query_pos=query_pos, + key_pos=key_pos, + attn_masks=attn_masks, + query_key_padding_mask=query_key_padding_mask, + key_padding_mask=key_padding_mask) + return x + + +@MODELS.register_module() +class PETRMultiheadAttention(BaseModule): + """A wrapper for ``torch.nn.MultiheadAttention``. + + This module implements MultiheadAttention with identity connection, + and positional encoding is also passed as input. + Args: + embed_dims (int): The embedding dimension. + num_heads (int): Parallel attention heads. + attn_drop (float): A Dropout layer on attn_output_weights. + Default: 0.0. + proj_drop (float): A Dropout layer after `nn.MultiheadAttention`. + Default: 0.0. + dropout_layer (obj:`ConfigDict`): The dropout_layer used + when adding the shortcut. + init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization. + Default: None. + batch_first (bool): When it is True, Key, Query and Value are shape of + (batch, n, embed_dim), otherwise (n, batch, embed_dim). + Default to False. 
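+
+    Example:
+        A minimal shape sketch (illustrative only, not part of the original
+        code; assumes mmcv's `Dropout` layer is available in the registry):
+
+        >>> import torch
+        >>> attn = PETRMultiheadAttention(embed_dims=256, num_heads=8)
+        >>> query = torch.rand(900, 2, 256)   # [num_query, bs, C]
+        >>> feats = torch.rand(6000, 2, 256)  # [num_cams*h*w, bs, C]
+        >>> out = attn(query=query, key=feats, value=feats)
+        >>> tuple(out.shape)
+        (900, 2, 256)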
+ """ + + def __init__(self, + embed_dims, + num_heads, + attn_drop=0., + proj_drop=0., + dropout_layer=dict(type='Dropout', drop_prob=0.), + init_cfg=None, + batch_first=False, + **kwargs): + super(PETRMultiheadAttention, self).__init__(init_cfg) + if 'dropout' in kwargs: + warnings.warn( + 'The arguments `dropout` in MultiheadAttention ' + 'has been deprecated, now you can separately ' + 'set `attn_drop`(float), proj_drop(float), ' + 'and `dropout_layer`(dict) ', DeprecationWarning) + attn_drop = kwargs['dropout'] + dropout_layer['drop_prob'] = kwargs.pop('dropout') + + self.embed_dims = embed_dims + self.num_heads = num_heads + self.batch_first = batch_first + + self.attn = nn.MultiheadAttention(embed_dims, num_heads, attn_drop, + **kwargs) + + self.proj_drop = nn.Dropout(proj_drop) + self.dropout_layer = MODELS.build( + dropout_layer) if dropout_layer else nn.Identity() + + # @deprecated_api_warning({'residual': 'identity'}, + # cls_name='MultiheadAttention') + def forward(self, + query, + key=None, + value=None, + identity=None, + query_pos=None, + key_pos=None, + attn_mask=None, + key_padding_mask=None, + **kwargs): + """Forward function for `MultiheadAttention`. + + **kwargs allow passing a more general data flow when combining + with other operations in `transformerlayer`. + Args: + query (Tensor): The input query with shape [num_queries, bs, + embed_dims] if self.batch_first is False, else + [bs, num_queries embed_dims]. + key (Tensor): The key tensor with shape [num_keys, bs, + embed_dims] if self.batch_first is False, else + [bs, num_keys, embed_dims] . + If None, the ``query`` will be used. Defaults to None. + value (Tensor): The value tensor with same shape as `key`. + Same in `nn.MultiheadAttention.forward`. Defaults to None. + If None, the `key` will be used. + identity (Tensor): This tensor, with the same shape as x, + will be used for the identity link. + If None, `x` will be used. Defaults to None. + query_pos (Tensor): The positional encoding for query, with + the same shape as `x`. If not None, it will + be added to `x` before forward function. Defaults to None. + key_pos (Tensor): The positional encoding for `key`, with the + same shape as `key`. Defaults to None. If not None, it will + be added to `key` before forward function. If None, and + `query_pos` has the same shape as `key`, then `query_pos` + will be used for `key_pos`. Defaults to None. + attn_mask (Tensor): ByteTensor mask with shape [num_queries, + num_keys]. Same in `nn.MultiheadAttention.forward`. + Defaults to None. + key_padding_mask (Tensor): ByteTensor with shape [bs, num_keys]. + Defaults to None. + Returns: + Tensor: forwarded results with shape + [num_queries, bs, embed_dims] + if self.batch_first is False, else + [bs, num_queries embed_dims]. 
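+
+        Note:
+            `identity` is captured before the positional encodings are
+            added, so the residual connection uses the position-free query
+            while the attention itself sees `query + query_pos` and
+            `key + key_pos`.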
+ """ + + if key is None: + key = query + if value is None: + value = key + if identity is None: + identity = query + if key_pos is None: + if query_pos is not None: + # use query_pos if key_pos is not available + if query_pos.shape == key.shape: + key_pos = query_pos + else: + warnings.warn(f'position encoding of key is' + f'missing in {self.__class__.__name__}.') + if query_pos is not None: + query = query + query_pos + if key_pos is not None: + key = key + key_pos + + # Because the dataflow('key', 'query', 'value') of + # ``torch.nn.MultiheadAttention`` is (num_query, batch, + # embed_dims), We should adjust the shape of dataflow from + # batch_first (batch, num_query, embed_dims) to num_query_first + # (num_query ,batch, embed_dims), and recover ``attn_output`` + # from num_query_first to batch_first. + if self.batch_first: + query = query.transpose(0, 1) + key = key.transpose(0, 1) + value = value.transpose(0, 1) + + out = self.attn( + query=query, + key=key, + value=value, + attn_mask=attn_mask, + key_padding_mask=key_padding_mask)[0] + + if self.batch_first: + out = out.transpose(0, 1) + + return identity + self.dropout_layer(self.proj_drop(out)) + + +@MODELS.register_module() +class PETRTransformerEncoder(TransformerLayerSequence): + """TransformerEncoder of DETR. + + Args: + post_norm_cfg (dict): Config of last normalization layer. Default: + `LN`. Only used when `self.pre_norm` is `True` + """ + + def __init__(self, *args, post_norm_cfg=dict(type='LN'), **kwargs): + super(PETRTransformerEncoder, self).__init__(*args, **kwargs) + if post_norm_cfg is not None: + self.post_norm = TASK_UTILS.build( + post_norm_cfg, self.embed_dims)[1] if self.pre_norm else None + else: + assert not self.pre_norm, f'Use prenorm in ' \ + f'{self.__class__.__name__},' \ + f'Please specify post_norm_cfg' + self.post_norm = None + + def forward(self, *args, **kwargs): + """Forward function for `TransformerCoder`. + + Returns: + Tensor: forwarded results with shape [num_query, bs, embed_dims]. + """ + x = super(PETRTransformerEncoder, self).forward(*args, **kwargs) + if self.post_norm is not None: + x = self.post_norm(x) + return x + + +@MODELS.register_module() +class PETRTransformerDecoder(TransformerLayerSequence): + """Implements the decoder in DETR transformer. + + Args: + return_intermediate (bool): Whether to return intermediate outputs. + post_norm_cfg (dict): Config of last normalization layer. Default: + `LN`. + """ + + def __init__(self, + *args, + post_norm_cfg=dict(type='LN'), + return_intermediate=False, + **kwargs): + + super(PETRTransformerDecoder, self).__init__(*args, **kwargs) + self.return_intermediate = return_intermediate + if post_norm_cfg is not None: + self.post_norm = build_norm_layer(post_norm_cfg, + self.embed_dims)[1] + else: + self.post_norm = None + + def forward(self, query, *args, **kwargs): + """Forward function for `TransformerDecoder`. + Args: + query (Tensor): Input query with shape + `(num_query, bs, embed_dims)`. + Returns: + Tensor: Results with shape [1, num_query, bs, embed_dims] when + return_intermediate is `False`, otherwise it has shape + [num_layers, num_query, bs, embed_dims]. 
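+
+        Note:
+            PETR configures this decoder with `return_intermediate=True`,
+            so the (post-normed) output of every layer is stacked and each
+            level is supervised by its own classification/regression branch
+            in `PETRHead`.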
+ """ + if not self.return_intermediate: + x = super().forward(query, *args, **kwargs) + if self.post_norm: + x = self.post_norm(x)[None] + return x + + intermediate = [] + for layer in self.layers: + query = layer(query, *args, **kwargs) + if self.return_intermediate: + if self.post_norm is not None: + intermediate.append(self.post_norm(query)) + else: + intermediate.append(query) + return torch.stack(intermediate) diff --git a/projects/PETR/petr/positional_encoding.py b/projects/PETR/petr/positional_encoding.py new file mode 100644 index 0000000000..2fb0a007aa --- /dev/null +++ b/projects/PETR/petr/positional_encoding.py @@ -0,0 +1,171 @@ +# ------------------------------------------------------------------------ +# Copyright (c) 2022 megvii-model. All Rights Reserved. +# ------------------------------------------------------------------------ +# Modified from mmdetection (https://github.com/open-mmlab/mmdetection) +# Copyright (c) OpenMMLab. All rights reserved. +# ------------------------------------------------------------------------ +import math + +import torch +import torch.nn as nn +from mmengine.model import BaseModule + +from mmdet3d.registry import MODELS, TASK_UTILS + + +@TASK_UTILS.register_module() +class SinePositionalEncoding3D(BaseModule): + """Position encoding with sine and cosine functions. See `End-to-End Object + Detection with Transformers. + + `_ for details. + Args: + num_feats (int): The feature dimension for each position + along x-axis or y-axis. Note the final returned dimension + for each position is 2 times of this value. + temperature (int, optional): The temperature used for scaling + the position embedding. Defaults to 10000. + normalize (bool, optional): Whether to normalize the position + embedding. Defaults to False. + scale (float, optional): A scale factor that scales the position + embedding. The scale will be used only when `normalize` is True. + Defaults to 2*pi. + eps (float, optional): A value added to the denominator for + numerical stability. Defaults to 1e-6. + offset (float): offset add to embed when do the normalization. + Defaults to 0. + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None + """ + + def __init__(self, + num_feats, + temperature=10000, + normalize=False, + scale=2 * math.pi, + eps=1e-6, + offset=0., + init_cfg=None): + super(SinePositionalEncoding3D, self).__init__(init_cfg) + if normalize: + assert isinstance(scale, (float, int)), 'when normalize is set,' \ + 'scale should be provided and in float or int type, ' \ + f'found {type(scale)}' + self.num_feats = num_feats + self.temperature = temperature + self.normalize = normalize + self.scale = scale + self.eps = eps + self.offset = offset + + def forward(self, mask): + """Forward function for `SinePositionalEncoding`. + Args: + mask (Tensor): ByteTensor mask. Non-zero values representing + ignored positions, while zero values means valid positions + for this image. Shape [bs, h, w]. + Returns: + pos (Tensor): Returned position embedding with shape + [bs, num_feats*2, h, w]. + """ + # For convenience of exporting to ONNX, it's required to convert + # `masks` from bool to int. 
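+        # The cumulative sums below extend the 2D sine encoding with a
+        # third sum over the camera axis (`n_embed`), giving every one of
+        # the N views its own embedding; the three sin/cos maps are
+        # concatenated to 3 * num_feats channels, which matches the
+        # `embed_dims * 3 // 2` input of `adapt_pos3d` in `PETRHead`.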
+ mask = mask.to(torch.int) + not_mask = 1 - mask # logical_not + n_embed = not_mask.cumsum(1, dtype=torch.float32) + y_embed = not_mask.cumsum(2, dtype=torch.float32) + x_embed = not_mask.cumsum(3, dtype=torch.float32) + if self.normalize: + n_embed = (n_embed + self.offset) / \ + (n_embed[:, -1:, :, :] + self.eps) * self.scale + y_embed = (y_embed + self.offset) / \ + (y_embed[:, :, -1:, :] + self.eps) * self.scale + x_embed = (x_embed + self.offset) / \ + (x_embed[:, :, :, -1:] + self.eps) * self.scale + dim_t = torch.arange( + self.num_feats, dtype=torch.float32, device=mask.device) + dim_t = self.temperature**(2 * (dim_t // 2) / self.num_feats) + pos_n = n_embed[:, :, :, :, None] / dim_t + pos_x = x_embed[:, :, :, :, None] / dim_t + pos_y = y_embed[:, :, :, :, None] / dim_t + # use `view` instead of `flatten` for dynamically exporting to ONNX + B, N, H, W = mask.size() + pos_n = torch.stack( + (pos_n[:, :, :, :, 0::2].sin(), pos_n[:, :, :, :, 1::2].cos()), + dim=4).view(B, N, H, W, -1) + pos_x = torch.stack( + (pos_x[:, :, :, :, 0::2].sin(), pos_x[:, :, :, :, 1::2].cos()), + dim=4).view(B, N, H, W, -1) + pos_y = torch.stack( + (pos_y[:, :, :, :, 0::2].sin(), pos_y[:, :, :, :, 1::2].cos()), + dim=4).view(B, N, H, W, -1) + pos = torch.cat((pos_n, pos_y, pos_x), dim=4).permute(0, 1, 4, 2, 3) + return pos + + def __repr__(self): + """str: a string that describes the module""" + repr_str = self.__class__.__name__ + repr_str += f'(num_feats={self.num_feats}, ' + repr_str += f'temperature={self.temperature}, ' + repr_str += f'normalize={self.normalize}, ' + repr_str += f'scale={self.scale}, ' + repr_str += f'eps={self.eps})' + return repr_str + + +@MODELS.register_module() +class LearnedPositionalEncoding3D(BaseModule): + """Position embedding with learnable embedding weights. + + Args: + num_feats (int): The feature dimension for each position + along x-axis or y-axis. The final returned dimension for + each position is 2 times of this value. + row_num_embed (int, optional): The dictionary size of row embeddings. + Default 50. + col_num_embed (int, optional): The dictionary size of col embeddings. + Default 50. + init_cfg (dict or list[dict], optional): Initialization config dict. + """ + + def __init__(self, + num_feats, + row_num_embed=50, + col_num_embed=50, + init_cfg=dict(type='Uniform', layer='Embedding')): + super(LearnedPositionalEncoding3D, self).__init__(init_cfg) + self.row_embed = nn.Embedding(row_num_embed, num_feats) + self.col_embed = nn.Embedding(col_num_embed, num_feats) + self.num_feats = num_feats + self.row_num_embed = row_num_embed + self.col_num_embed = col_num_embed + + def forward(self, mask): + """Forward function for `LearnedPositionalEncoding`. + Args: + mask (Tensor): ByteTensor mask. Non-zero values representing + ignored positions, while zero values means valid positions + for this image. Shape [bs, h, w]. + Returns: + pos (Tensor): Returned position embedding with shape + [bs, num_feats*2, h, w]. 
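+
+        Example:
+            A shape sketch (illustrative values only; `h` and `w` must not
+            exceed `row_num_embed`/`col_num_embed`, which default to 50):
+
+            >>> import torch
+            >>> pe = LearnedPositionalEncoding3D(num_feats=128)
+            >>> mask = torch.zeros(2, 20, 50)
+            >>> tuple(pe(mask).shape)
+            (2, 256, 20, 50)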
+ """ + h, w = mask.shape[-2:] + x = torch.arange(w, device=mask.device) + y = torch.arange(h, device=mask.device) + x_embed = self.col_embed(x) + y_embed = self.row_embed(y) + pos = torch.cat( + (x_embed.unsqueeze(0).repeat(h, 1, 1), y_embed.unsqueeze(1).repeat( + 1, w, 1)), + dim=-1).permute(2, 0, + 1).unsqueeze(0).repeat(mask.shape[0], 1, 1, 1) + return pos + + def __repr__(self): + """str: a string that describes the module""" + repr_str = self.__class__.__name__ + repr_str += f'(num_feats={self.num_feats}, ' + repr_str += f'row_num_embed={self.row_num_embed}, ' + repr_str += f'col_num_embed={self.col_num_embed})' + return repr_str diff --git a/projects/PETR/petr/transforms_3d.py b/projects/PETR/petr/transforms_3d.py new file mode 100644 index 0000000000..c8f998658f --- /dev/null +++ b/projects/PETR/petr/transforms_3d.py @@ -0,0 +1,207 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import numpy as np +import torch +from mmcv.transforms import BaseTransform +from PIL import Image + +from mmdet3d.registry import TRANSFORMS +from mmdet3d.structures.bbox_3d import LiDARInstance3DBoxes + + +@TRANSFORMS.register_module() +class ResizeCropFlipImage(BaseTransform): + """Random resize, Crop and flip the image + Args: + size (tuple, optional): Fixed padding size. + """ + + def __init__(self, data_aug_conf=None, training=True): + self.data_aug_conf = data_aug_conf + self.training = training + + def transform(self, results): + """Call function to pad images, masks, semantic segmentation maps. + + Args: + results (dict): Result dict from loading pipeline. + Returns: + dict: Updated result dict. + """ + + imgs = results['img'] + N = len(imgs) + new_imgs = [] + resize, resize_dims, crop, flip, rotate = self._sample_augmentation() + results['lidar2cam'] = np.array(results['lidar2cam']) + for i in range(N): + intrinsic = np.array(results['cam2img'][i]) + viewpad = np.eye(4) + viewpad[:intrinsic.shape[0], :intrinsic.shape[1]] = intrinsic + results['cam2img'][i] = viewpad + img = Image.fromarray(np.uint8(imgs[i])) + # augmentation (resize, crop, horizontal flip, rotate) + # different view use different aug (BEV Det) + img, ida_mat = self._img_transform( + img, + resize=resize, + resize_dims=resize_dims, + crop=crop, + flip=flip, + rotate=rotate, + ) + new_imgs.append(np.array(img).astype(np.float32)) + results['cam2img'][ + i][:3, :3] = ida_mat @ results['cam2img'][i][:3, :3] + + results['img'] = new_imgs + + return results + + def _get_rot(self, h): + + return torch.Tensor([ + [np.cos(h), np.sin(h)], + [-np.sin(h), np.cos(h)], + ]) + + def _img_transform(self, img, resize, resize_dims, crop, flip, rotate): + ida_rot = torch.eye(2) + ida_tran = torch.zeros(2) + # adjust image + img = img.resize(resize_dims) + img = img.crop(crop) + if flip: + img = img.transpose(method=Image.FLIP_LEFT_RIGHT) + img = img.rotate(rotate) + + # post-homography transformation + ida_rot *= resize + ida_tran -= torch.Tensor(crop[:2]) + if flip: + A = torch.Tensor([[-1, 0], [0, 1]]) + b = torch.Tensor([crop[2] - crop[0], 0]) + ida_rot = A.matmul(ida_rot) + ida_tran = A.matmul(ida_tran) + b + A = self._get_rot(rotate / 180 * np.pi) + b = torch.Tensor([crop[2] - crop[0], crop[3] - crop[1]]) / 2 + b = A.matmul(-b) + b + ida_rot = A.matmul(ida_rot) + ida_tran = A.matmul(ida_tran) + b + ida_mat = torch.eye(3) + ida_mat[:2, :2] = ida_rot + ida_mat[:2, 2] = ida_tran + return img, ida_mat + + def _sample_augmentation(self): + H, W = self.data_aug_conf['H'], self.data_aug_conf['W'] + fH, fW = self.data_aug_conf['final_dim'] + if 
+        if self.training:
+            resize = np.random.uniform(*self.data_aug_conf['resize_lim'])
+            resize_dims = (int(W * resize), int(H * resize))
+            newW, newH = resize_dims
+            crop_h = int(
+                (1 - np.random.uniform(*self.data_aug_conf['bot_pct_lim'])) *
+                newH) - fH
+            crop_w = int(np.random.uniform(0, max(0, newW - fW)))
+            crop = (crop_w, crop_h, crop_w + fW, crop_h + fH)
+            flip = False
+            if self.data_aug_conf['rand_flip'] and np.random.choice([0, 1]):
+                flip = True
+            rotate = np.random.uniform(*self.data_aug_conf['rot_lim'])
+        else:
+            resize = max(fH / H, fW / W)
+            resize_dims = (int(W * resize), int(H * resize))
+            newW, newH = resize_dims
+            crop_h = int(
+                (1 - np.mean(self.data_aug_conf['bot_pct_lim'])) * newH) - fH
+            crop_w = int(max(0, newW - fW) / 2)
+            crop = (crop_w, crop_h, crop_w + fW, crop_h + fH)
+            flip = False
+            rotate = 0
+        return resize, resize_dims, crop, flip, rotate
+
+
+@TRANSFORMS.register_module()
+class GlobalRotScaleTransImage(BaseTransform):
+    """Randomly rotate and scale the 3D scene.
+
+    Both the ground-truth boxes and the ``lidar2cam`` extrinsics are
+    updated so that the projected geometry stays consistent.
+
+    Args:
+        rot_range (list[float]): Range of the random rotation around the
+            z-axis, in radians. Defaults to [-0.3925, 0.3925].
+        scale_ratio_range (list[float]): Range of the random scale ratio.
+            Defaults to [0.95, 1.05].
+        translation_std (list[float]): Std of the random translation.
+            Currently unused, as translation is not supported yet.
+        reverse_angle (bool): Whether to reverse the rotation angle applied
+            to the ground-truth boxes. Defaults to False.
+        training (bool): Whether the transform is used during training.
+            Defaults to True.
+    """
+
+    def __init__(
+        self,
+        rot_range=[-0.3925, 0.3925],
+        scale_ratio_range=[0.95, 1.05],
+        translation_std=[0, 0, 0],
+        reverse_angle=False,
+        training=True,
+    ):
+
+        self.rot_range = rot_range
+        self.scale_ratio_range = scale_ratio_range
+        self.translation_std = translation_std
+
+        self.reverse_angle = reverse_angle
+        self.training = training
+
+    def transform(self, results):
+        """Apply a random global rotation and scaling to the ``lidar2cam``
+        extrinsics and the ground-truth boxes.
+
+        Args:
+            results (dict): Result dict from the loading pipeline.
+        Returns:
+            dict: Updated result dict.
+        """
+        # random rotate
+        rot_angle = np.random.uniform(*self.rot_range)
+
+        self.rotate_bev_along_z(results, rot_angle)
+        if self.reverse_angle:
+            rot_angle *= -1
+        results['gt_bboxes_3d'].rotate(np.array(rot_angle))
+
+        # random scale
+        scale_ratio = np.random.uniform(*self.scale_ratio_range)
+        self.scale_xyz(results, scale_ratio)
+        results['gt_bboxes_3d'].scale(scale_ratio)
+
+        # TODO: support translation
+        if not self.reverse_angle:
+            gt_bboxes_3d = results['gt_bboxes_3d'].tensor.numpy()
+            gt_bboxes_3d[:, 6] -= 2 * rot_angle
+            results['gt_bboxes_3d'] = LiDARInstance3DBoxes(
+                gt_bboxes_3d, box_dim=9)
+
+        return results
+
+    def rotate_bev_along_z(self, results, angle):
+        rot_cos = torch.cos(torch.tensor(angle))
+        rot_sin = torch.sin(torch.tensor(angle))
+
+        rot_mat = torch.tensor([[rot_cos, -rot_sin, 0, 0],
+                                [rot_sin, rot_cos, 0, 0], [0, 0, 1, 0],
+                                [0, 0, 0, 1]])
+        rot_mat_inv = torch.inverse(rot_mat)
+        num_view = len(results['lidar2cam'])
+        for view in range(num_view):
+            results['lidar2cam'][view] = (
+                torch.tensor(np.array(results['lidar2cam'][view])).float()
+                @ rot_mat_inv).numpy()
+
+        return
+
+    def scale_xyz(self, results, scale_ratio):
+        scale_mat = torch.tensor([
+            [scale_ratio, 0, 0, 0],
+            [0, scale_ratio, 0, 0],
+            [0, 0, scale_ratio, 0],
+            [0, 0, 0, 1],
+        ])
+
+        scale_mat_inv = torch.inverse(scale_mat)
+
+        num_view = len(results['lidar2cam'])
+        for view in range(num_view):
+            results['lidar2cam'][view] = (torch.tensor(
+                scale_mat_inv.T @ results['lidar2cam'][view]).float()).numpy()
+        return
diff --git a/projects/PETR/petr/utils.py b/projects/PETR/petr/utils.py
new file mode 100644
index 0000000000..6eca428bf3
--- /dev/null
+++ b/projects/PETR/petr/utils.py
@@ -0,0 +1,69 @@
+# Copyright (c) OpenMMLab. All rights reserved.
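+"""Box-encoding helpers for PETR.
+
+``normalize_bbox`` and ``denormalize_bbox`` convert between the mmdet3d box
+parametrization ``(cx, cy, cz, w, l, h, yaw[, vx, vy])`` and the regression
+space used by the detector (log-scaled sizes and sin/cos-encoded yaw).
+"""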
+import numpy as np
+import torch
+
+import mmdet3d
+from mmdet3d.structures.bbox_3d.utils import limit_period
+
+
+def normalize_bbox(bboxes, pc_range):
+    # ``pc_range`` is kept for interface compatibility and is not used here.
+    cx = bboxes[..., 0:1]
+    cy = bboxes[..., 1:2]
+    cz = bboxes[..., 2:3]
+    w = bboxes[..., 3:4].log()
+    length = bboxes[..., 4:5].log()
+    h = bboxes[..., 5:6].log()
+
+    rot = bboxes[..., 6:7]
+    if bboxes.size(-1) > 7:
+        vx = bboxes[..., 7:8]
+        vy = bboxes[..., 8:9]
+        normalized_bboxes = torch.cat(
+            (cx, cy, w, length, cz, h, rot.sin(), rot.cos(), vx, vy), dim=-1)
+    else:
+        normalized_bboxes = torch.cat(
+            (cx, cy, w, length, cz, h, rot.sin(), rot.cos()), dim=-1)
+    return normalized_bboxes
+
+
+def denormalize_bbox(normalized_bboxes, pc_range):
+    # rotation
+    rot_sine = normalized_bboxes[..., 6:7]
+
+    rot_cosine = normalized_bboxes[..., 7:8]
+    rot = torch.atan2(rot_sine, rot_cosine)
+
+    # center in the bev
+    cx = normalized_bboxes[..., 0:1]
+    cy = normalized_bboxes[..., 1:2]
+    cz = normalized_bboxes[..., 4:5]
+
+    # size
+    w = normalized_bboxes[..., 2:3]
+    length = normalized_bboxes[..., 3:4]
+    h = normalized_bboxes[..., 5:6]
+
+    w = w.exp()
+    length = length.exp()
+    h = h.exp()
+    if normalized_bboxes.size(-1) > 8:
+        # velocity
+        vx = normalized_bboxes[:, 8:9]
+        vy = normalized_bboxes[:, 9:10]
+        denormalized_bboxes = torch.cat(
+            [cx, cy, cz, w, length, h, rot, vx, vy], dim=-1)
+    else:
+        denormalized_bboxes = torch.cat([cx, cy, cz, w, length, h, rot],
+                                        dim=-1)
+
+    if int(mmdet3d.__version__[0]) >= 1:
+        # mmdet3d >= 1.0 changed the box convention: swap width/length and
+        # remap the yaw angle, then wrap it into [-pi, pi).
+        denormalized_bboxes_clone = denormalized_bboxes.clone()
+        denormalized_bboxes[:, 3] = denormalized_bboxes_clone[:, 4]
+        denormalized_bboxes[:, 4] = denormalized_bboxes_clone[:, 3]
+        # change yaw
+        denormalized_bboxes[:,
+                            6] = -denormalized_bboxes_clone[:, 6] - np.pi / 2
+        denormalized_bboxes[:, 6] = limit_period(
+            denormalized_bboxes[:, 6], period=np.pi * 2)
+    return denormalized_bboxes
diff --git a/projects/PETR/petr/vovnetcp.py b/projects/PETR/petr/vovnetcp.py
new file mode 100644
index 0000000000..62f0fdeafb
--- /dev/null
+++ b/projects/PETR/petr/vovnetcp.py
@@ -0,0 +1,475 @@
+# ------------------------------------------------------------------------
+# Copyright (c) 2022 megvii-model. All Rights Reserved.
+# ------------------------------------------------------------------------
+# Modified from DETR3D (https://github.com/WangYueFt/detr3d)
+# Copyright (c) 2021 Wang, Yue
+# ------------------------------------------------------------------------
+# Copyright (c) Youngwan Lee (ETRI) All Rights Reserved.
+# Copyright 2021 Toyota Research Institute. All rights reserved.
+# ------------------------------------------------------------------------
+import warnings
+from collections import OrderedDict
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.checkpoint as cp
+from mmengine.model import BaseModule
+from torch.nn.modules.batchnorm import _BatchNorm
+
+from mmdet3d.registry import MODELS
+
+VoVNet19_slim_dw_eSE = {
+    'stem': [64, 64, 64],
+    'stage_conv_ch': [64, 80, 96, 112],
+    'stage_out_ch': [112, 256, 384, 512],
+    'layer_per_block': 3,
+    'block_per_stage': [1, 1, 1, 1],
+    'eSE': True,
+    'dw': True
+}
+
+VoVNet19_dw_eSE = {
+    'stem': [64, 64, 64],
+    'stage_conv_ch': [128, 160, 192, 224],
+    'stage_out_ch': [256, 512, 768, 1024],
+    'layer_per_block': 3,
+    'block_per_stage': [1, 1, 1, 1],
+    'eSE': True,
+    'dw': True
+}
+
+VoVNet19_slim_eSE = {
+    'stem': [64, 64, 128],
+    'stage_conv_ch': [64, 80, 96, 112],
+    'stage_out_ch': [112, 256, 384, 512],
+    'layer_per_block': 3,
+    'block_per_stage': [1, 1, 1, 1],
+    'eSE': True,
+    'dw': False
+}
+
+VoVNet19_eSE = {
+    'stem': [64, 64, 128],
+    'stage_conv_ch': [128, 160, 192, 224],
+    'stage_out_ch': [256, 512, 768, 1024],
+    'layer_per_block': 3,
+    'block_per_stage': [1, 1, 1, 1],
+    'eSE': True,
+    'dw': False
+}
+
+VoVNet39_eSE = {
+    'stem': [64, 64, 128],
+    'stage_conv_ch': [128, 160, 192, 224],
+    'stage_out_ch': [256, 512, 768, 1024],
+    'layer_per_block': 5,
+    'block_per_stage': [1, 1, 2, 2],
+    'eSE': True,
+    'dw': False
+}
+
+VoVNet57_eSE = {
+    'stem': [64, 64, 128],
+    'stage_conv_ch': [128, 160, 192, 224],
+    'stage_out_ch': [256, 512, 768, 1024],
+    'layer_per_block': 5,
+    'block_per_stage': [1, 1, 4, 3],
+    'eSE': True,
+    'dw': False
+}
+
+VoVNet99_eSE = {
+    'stem': [64, 64, 128],
+    'stage_conv_ch': [128, 160, 192, 224],
+    'stage_out_ch': [256, 512, 768, 1024],
+    'layer_per_block': 5,
+    'block_per_stage': [1, 3, 9, 3],
+    'eSE': True,
+    'dw': False
+}
+
+_STAGE_SPECS = {
+    'V-19-slim-dw-eSE': VoVNet19_slim_dw_eSE,
+    'V-19-dw-eSE': VoVNet19_dw_eSE,
+    'V-19-slim-eSE': VoVNet19_slim_eSE,
+    'V-19-eSE': VoVNet19_eSE,
+    'V-39-eSE': VoVNet39_eSE,
+    'V-57-eSE': VoVNet57_eSE,
+    'V-99-eSE': VoVNet99_eSE,
+}
+
+
+def dw_conv3x3(in_channels,
+               out_channels,
+               module_name,
+               postfix,
+               stride=1,
+               kernel_size=3,
+               padding=1):
+    """Depthwise-separable 3x3 convolution block: a depthwise 3x3 conv
+    followed by a pointwise 1x1 conv, BN and ReLU."""
+    return [
+        ('{}_{}/dw_conv3x3'.format(module_name, postfix),
+         nn.Conv2d(
+             in_channels,
+             out_channels,
+             kernel_size=kernel_size,
+             stride=stride,
+             padding=padding,
+             groups=out_channels,
+             bias=False)),
+        ('{}_{}/pw_conv1x1'.format(module_name, postfix),
+         nn.Conv2d(
+             in_channels,
+             out_channels,
+             kernel_size=1,
+             stride=1,
+             padding=0,
+             groups=1,
+             bias=False)),
+        ('{}_{}/pw_norm'.format(module_name,
+                                postfix), nn.BatchNorm2d(out_channels)),
+        ('{}_{}/pw_relu'.format(module_name, postfix), nn.ReLU(inplace=True)),
+    ]
+
+
+def conv3x3(in_channels,
+            out_channels,
+            module_name,
+            postfix,
+            stride=1,
+            groups=1,
+            kernel_size=3,
+            padding=1):
+    """3x3 convolution with padding, followed by BN and ReLU."""
+    return [
+        (
+            f'{module_name}_{postfix}/conv',
+            nn.Conv2d(
+                in_channels,
+                out_channels,
+                kernel_size=kernel_size,
+                stride=stride,
+                padding=padding,
+                groups=groups,
+                bias=False,
+            ),
+        ),
+        (f'{module_name}_{postfix}/norm', nn.BatchNorm2d(out_channels)),
+        (f'{module_name}_{postfix}/relu', nn.ReLU(inplace=True)),
+    ]
+
+
+def conv1x1(in_channels,
+            out_channels,
+            module_name,
+            postfix,
+            stride=1,
+            groups=1,
+            kernel_size=1,
+            padding=0):
+    """1x1 convolution, followed by BN and ReLU."""
+    return [
+        (
+            f'{module_name}_{postfix}/conv',
+            nn.Conv2d(
+                in_channels,
+                out_channels,
+                kernel_size=kernel_size,
+                stride=stride,
+                padding=padding,
+                groups=groups,
+                bias=False,
+            ),
+        ),
+        (f'{module_name}_{postfix}/norm', nn.BatchNorm2d(out_channels)),
+        (f'{module_name}_{postfix}/relu', nn.ReLU(inplace=True)),
+    ]
+
+
+class Hsigmoid(nn.Module):
+    """Hard sigmoid activation: ``relu6(x + 3) / 6``."""
+
+    def __init__(self, inplace=True):
+        super(Hsigmoid, self).__init__()
+        self.inplace = inplace
+
+    def forward(self, x):
+        return F.relu6(x + 3.0, inplace=self.inplace) / 6.0
+
+
+class eSEModule(nn.Module):
+    """Effective Squeeze-and-Excitation (eSE) channel attention."""
+
+    def __init__(self, channel, reduction=4):
+        super(eSEModule, self).__init__()
+        self.avg_pool = nn.AdaptiveAvgPool2d(1)
+        self.fc = nn.Conv2d(channel, channel, kernel_size=1, padding=0)
+        self.hsigmoid = Hsigmoid()
+
+    def forward(self, x):
+        identity = x
+        x = self.avg_pool(x)
+        x = self.fc(x)
+        x = self.hsigmoid(x)
+        return identity * x
+
+
+class _OSA_module(nn.Module):
+    """One-Shot Aggregation (OSA) module of VoVNet."""
+
+    def __init__(self,
+                 in_ch,
+                 stage_ch,
+                 concat_ch,
+                 layer_per_block,
+                 module_name,
+                 SE=False,
+                 identity=False,
+                 depthwise=False,
+                 with_cp=True):
+
+        super(_OSA_module, self).__init__()
+
+        self.identity = identity
+        self.depthwise = depthwise
+        self.isReduced = False
+        self.use_checkpoint = with_cp
+        self.layers = nn.ModuleList()
+        in_channel = in_ch
+        if self.depthwise and in_channel != stage_ch:
+            self.isReduced = True
+            self.conv_reduction = nn.Sequential(
+                OrderedDict(
+                    conv1x1(in_channel, stage_ch,
+                            '{}_reduction'.format(module_name), '0')))
+        for i in range(layer_per_block):
+            if self.depthwise:
+                self.layers.append(
+                    nn.Sequential(
+                        OrderedDict(
+                            dw_conv3x3(stage_ch, stage_ch, module_name, i))))
+            else:
+                self.layers.append(
+                    nn.Sequential(
+                        OrderedDict(
+                            conv3x3(in_channel, stage_ch, module_name, i))))
+            in_channel = stage_ch
+
+        # feature aggregation
+        in_channel = in_ch + layer_per_block * stage_ch
+        self.concat = nn.Sequential(
+            OrderedDict(conv1x1(in_channel, concat_ch, module_name, 'concat')))
+
+        self.ese = eSEModule(concat_ch)
+
+    def _forward(self, x):
+
+        identity_feat = x
+
+        output = []
+        output.append(x)
+        if self.depthwise and self.isReduced:
+            x = self.conv_reduction(x)
+        for layer in self.layers:
+            x = layer(x)
+            output.append(x)
+
+        x = torch.cat(output, dim=1)
+        xt = self.concat(x)
+
+        xt = self.ese(xt)
+
+        if self.identity:
+            xt = xt + identity_feat
+
+        return xt
+
+    def forward(self, x):
+        # gradient checkpointing saves memory during training
+        if self.use_checkpoint and self.training:
+            xt = cp.checkpoint(self._forward, x)
+        else:
+            xt = self._forward(x)
+
+        return xt
+
+
+class _OSA_stage(nn.Sequential):
+    """A stage of stacked OSA modules, downsampled by max pooling."""
+
+    def __init__(self,
+                 in_ch,
+                 stage_ch,
+                 concat_ch,
+                 block_per_stage,
+                 layer_per_block,
+                 stage_num,
+                 SE=False,
+                 depthwise=False):
+
+        super(_OSA_stage, self).__init__()
+
+        if stage_num != 2:
+            self.add_module(
+                'Pooling',
+                nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True))
+
+        if block_per_stage != 1:
+            SE = False
+        module_name = f'OSA{stage_num}_1'
+        self.add_module(
+            module_name,
+            _OSA_module(
+                in_ch,
+                stage_ch,
+                concat_ch,
+                layer_per_block,
+                module_name,
+                SE,
+                depthwise=depthwise))
+        for i in range(block_per_stage - 1):
+            if i != block_per_stage - 2:  # last block
+                SE = False
+            module_name = f'OSA{stage_num}_{i + 2}'
+            self.add_module(
+                module_name,
+                _OSA_module(
+                    concat_ch,
+                    stage_ch,
+                    concat_ch,
+                    layer_per_block,
+                    module_name,
+                    SE,
+                    identity=True,
+                    depthwise=depthwise),
+            )
+
+
+@MODELS.register_module()
+class VoVNetCP(BaseModule):
+    """VoVNet backbone with eSE attention and gradient checkpointing
+    support."""
+
+    def __init__(self,
+                 spec_name,
+                 input_ch=3,
+                 out_features=None,
+                 frozen_stages=-1,
+                 norm_eval=True,
+                 pretrained=None,
+                 init_cfg=None):
+        """
+        Args:
+            spec_name (str): Name of the architecture spec, one of the keys
+                of ``_STAGE_SPECS`` (e.g. 'V-99-eSE').
+            input_ch (int): Number of input channels.
+            out_features (list[str]): Names of the layers whose outputs
+                should be returned in forward. Can be any of 'stem',
+                'stage2', ..., 'stage5'.
+            frozen_stages (int): Stages to be frozen (parameters stop
+                requiring grad and the modules are set to eval mode).
+                -1 means no stage is frozen.
+            norm_eval (bool): Whether to set norm layers to eval mode
+                during training. Defaults to True.
+        """
+        super(VoVNetCP, self).__init__(init_cfg)
+        self.frozen_stages = frozen_stages
+        self.norm_eval = norm_eval
+
+        if isinstance(pretrained, str):
+            warnings.warn('DeprecationWarning: pretrained is deprecated, '
+                          'please use "init_cfg" instead')
+            self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
+        stage_specs = _STAGE_SPECS[spec_name]
+
+        stem_ch = stage_specs['stem']
+        config_stage_ch = stage_specs['stage_conv_ch']
+        config_concat_ch = stage_specs['stage_out_ch']
+        block_per_stage = stage_specs['block_per_stage']
+        layer_per_block = stage_specs['layer_per_block']
+        SE = stage_specs['eSE']
+        depthwise = stage_specs['dw']
+
+        self._out_features = out_features
+
+        # Stem module
+        conv_type = dw_conv3x3 if depthwise else conv3x3
+        stem = conv3x3(input_ch, stem_ch[0], 'stem', '1', 2)
+        stem += conv_type(stem_ch[0], stem_ch[1], 'stem', '2', 1)
+        stem += conv_type(stem_ch[1], stem_ch[2], 'stem', '3', 2)
+        self.add_module('stem', nn.Sequential(OrderedDict(stem)))
+        current_stride = 4
+        self._out_feature_strides = {
+            'stem': current_stride,
+            'stage2': current_stride
+        }
+        self._out_feature_channels = {'stem': stem_ch[2]}
+
+        stem_out_ch = [stem_ch[2]]
+        in_ch_list = stem_out_ch + config_concat_ch[:-1]
+        # OSA stages
+        self.stage_names = []
+        for i in range(4):  # num_stages
+            name = 'stage%d' % (i + 2)  # stage 2 ... stage 5
+            self.stage_names.append(name)
+            self.add_module(
+                name,
+                _OSA_stage(
+                    in_ch_list[i],
+                    config_stage_ch[i],
+                    config_concat_ch[i],
+                    block_per_stage[i],
+                    layer_per_block,
+                    i + 2,
+                    SE,
+                    depthwise,
+                ),
+            )
+
+            self._out_feature_channels[name] = config_concat_ch[i]
+            if i != 0:
+                self._out_feature_strides[name] = current_stride = int(
+                    current_stride * 2)
+
+        # initialize weights
+        # self._initialize_weights()
+
+    def _initialize_weights(self):
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                nn.init.kaiming_normal_(m.weight)
+
+    def forward(self, x):
+        """Extract the requested feature maps, returned in stage order."""
+        outputs = []
+        x = self.stem(x)
+        if 'stem' in self._out_features:
+            outputs.append(x)
+        for name in self.stage_names:
+            x = getattr(self, name)(x)
+            if name in self._out_features:
+                outputs.append(x)
+
+        return outputs
+
+    def _freeze_stages(self):
+        if self.frozen_stages >= 0:
+            m = self.stem
+            m.eval()
+            for param in m.parameters():
+                param.requires_grad = False
+
+        for i in range(1, self.frozen_stages + 1):
+            m = getattr(self, f'stage{i + 1}')
+            m.eval()
+            for param in m.parameters():
+                param.requires_grad = False
+
+    def train(self, mode=True):
+        """Convert the model into training mode while keeping the
+        normalization layers frozen."""
+        super(VoVNetCP, self).train(mode)
+        self._freeze_stages()
+        if mode and self.norm_eval:
+            for m in self.modules():
+                # trick: eval() only has an effect on BatchNorm layers here
+                if isinstance(m, _BatchNorm):
+                    m.eval()
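+
+
+# Minimal usage sketch (illustrative only; the registry build call and the
+# input size below are assumptions, not part of the tested pipeline):
+#
+#   from mmdet3d.registry import MODELS
+#   backbone = MODELS.build(
+#       dict(type='VoVNetCP', spec_name='V-99-eSE',
+#            out_features=('stage4', 'stage5')))
+#   feats = backbone(torch.rand(1, 3, 320, 800))
+#   # -> [stage4, stage5] feature maps with 768 and 1024 channels,
+#   #    matching ``stage_out_ch`` of ``VoVNet99_eSE`` above.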