From c8c46b06238676791f0e020da8cc1ffc02769b4c Mon Sep 17 00:00:00 2001 From: lupeng Date: Wed, 12 Apr 2023 18:10:39 +0800 Subject: [PATCH 01/12] support petr --- .../petr/configs/_base_/coco_detection.py | 67 ++ projects/petr/configs/_base_/datasets | 1 + .../petr/configs/_base_/default_runtime.py | 41 + .../petr/configs/petr_r101_8xb4-100e_coco.py | 4 + .../petr/configs/petr_r50_8xb4-100e_coco.py | 264 +++++++ .../configs/petr_swin-l_8xb4-100e_coco.py | 24 + projects/petr/datasets/__init__.py | 3 + .../petr/datasets/bbox_keypoint_structure.py | 285 +++++++ projects/petr/datasets/coco_dataset.py | 79 ++ projects/petr/datasets/transforms.py | 169 +++++ projects/petr/models/__init__.py | 4 + projects/petr/models/losses.py | 89 +++ projects/petr/models/match_costs.py | 63 ++ projects/petr/models/petr.py | 716 ++++++++++++++++++ projects/petr/models/petr_head.py | 658 ++++++++++++++++ projects/petr/models/transformers.py | 254 +++++++ 16 files changed, 2721 insertions(+) create mode 100644 projects/petr/configs/_base_/coco_detection.py create mode 120000 projects/petr/configs/_base_/datasets create mode 100644 projects/petr/configs/_base_/default_runtime.py create mode 100644 projects/petr/configs/petr_r101_8xb4-100e_coco.py create mode 100644 projects/petr/configs/petr_r50_8xb4-100e_coco.py create mode 100644 projects/petr/configs/petr_swin-l_8xb4-100e_coco.py create mode 100644 projects/petr/datasets/__init__.py create mode 100644 projects/petr/datasets/bbox_keypoint_structure.py create mode 100644 projects/petr/datasets/coco_dataset.py create mode 100644 projects/petr/datasets/transforms.py create mode 100644 projects/petr/models/__init__.py create mode 100644 projects/petr/models/losses.py create mode 100644 projects/petr/models/match_costs.py create mode 100644 projects/petr/models/petr.py create mode 100644 projects/petr/models/petr_head.py create mode 100644 projects/petr/models/transformers.py diff --git a/projects/petr/configs/_base_/coco_detection.py b/projects/petr/configs/_base_/coco_detection.py new file mode 100644 index 0000000000..1761a0a3cb --- /dev/null +++ b/projects/petr/configs/_base_/coco_detection.py @@ -0,0 +1,67 @@ +# dataset settings +dataset_type = 'mmpose.CocoDataset' +data_mode = 'bottomup' +data_root = 'data/coco/' + +# file_client_args = dict( +# backend='petrel', +# path_mapping=dict({ +# './data/': 's3://openmmlab/datasets/detection/', +# 'data/': 's3://openmmlab/datasets/detection/' +# })) +file_client_args = dict(backend='disk') + +train_pipeline = [ + dict(type='LoadImageFromFile', file_client_args=file_client_args), + dict(type='LoadAnnotations', with_bbox=True), + dict(type='Resize', scale=(1333, 800), keep_ratio=True), + dict(type='RandomFlip', prob=0.5), + dict(type='PackDetInputs') +] + +test_pipeline = [ + dict(type='LoadImageFromFile', file_client_args=file_client_args), + dict(type='Resize', scale=(1333, 800), keep_ratio=True), + # If you don't have a gt annotation, delete the pipeline + dict(type='LoadAnnotations', with_bbox=True), + dict( + type='PackDetInputs', + meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', + 'scale_factor')) +] +train_dataloader = dict( + batch_size=2, + num_workers=2, + persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=True), + batch_sampler=dict(type='AspectRatioBatchSampler'), + dataset=dict( + type=dataset_type, + data_root=data_root, + ann_file='annotations/person_keypoints_train2017.json', + data_prefix=dict(img='train2017/'), + filter_cfg=dict(filter_empty_gt=True, min_size=32), + pipeline=train_pipeline)) + +val_dataloader = dict( + batch_size=1, + num_workers=2, + persistent_workers=True, + drop_last=False, + sampler=dict(type='DefaultSampler', shuffle=False), + dataset=dict( + type=dataset_type, + data_root=data_root, + ann_file='annotations/person_keypoints_val2017.json', + data_prefix=dict(img='val2017/'), + test_mode=True, + pipeline=test_pipeline)) +test_dataloader = val_dataloader + +val_evaluator = dict( + type='CocoMetric', + ann_file=data_root + 'annotations/person_keypoints_val2017.json', +) +test_evaluator = val_evaluator + +default_hooks = dict(checkpoint=dict(save_best='coco/AP', rule='greater')) diff --git a/projects/petr/configs/_base_/datasets b/projects/petr/configs/_base_/datasets new file mode 120000 index 0000000000..8feca66d56 --- /dev/null +++ b/projects/petr/configs/_base_/datasets @@ -0,0 +1 @@ +../../../../configs/_base_/datasets \ No newline at end of file diff --git a/projects/petr/configs/_base_/default_runtime.py b/projects/petr/configs/_base_/default_runtime.py new file mode 100644 index 0000000000..7a2a84af27 --- /dev/null +++ b/projects/petr/configs/_base_/default_runtime.py @@ -0,0 +1,41 @@ +default_scope = 'mmdet' +custom_imports = dict(imports=['models', 'datasets']) + +# hooks +default_hooks = dict( + timer=dict(type='IterTimerHook'), + logger=dict(type='LoggerHook', interval=50), + param_scheduler=dict(type='ParamSchedulerHook'), + checkpoint=dict(type='CheckpointHook', interval=5, max_keep_ckpts=3), + sampler_seed=dict(type='DistSamplerSeedHook'), + visualization=dict(type='mmpose.PoseVisualizationHook', enable=False), +) + +# multi-processing backend +env_cfg = dict( + cudnn_benchmark=False, + mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0), + dist_cfg=dict(backend='nccl'), +) + +# visualizer +vis_backends = [dict(type='LocalVisBackend')] +visualizer = dict( + type='mmpose.PoseLocalVisualizer', + vis_backends=vis_backends, + name='visualizer') + +# logger +log_processor = dict( + type='LogProcessor', window_size=50, by_epoch=True, num_digits=6) +log_level = 'INFO' +load_from = None +resume = False + +# file I/O backend +file_client_args = dict(backend='disk') + +# training/validation/testing progress +train_cfg = dict() +val_cfg = dict(type='ValLoop') +test_cfg = dict(type='TestLoop') diff --git a/projects/petr/configs/petr_r101_8xb4-100e_coco.py b/projects/petr/configs/petr_r101_8xb4-100e_coco.py new file mode 100644 index 0000000000..9ac326f147 --- /dev/null +++ b/projects/petr/configs/petr_r101_8xb4-100e_coco.py @@ -0,0 +1,4 @@ +_base_ = ['petr_r50_8xb4-100e_coco.py'] + +# model +model = dict(backbone=dict(depth=101)) diff --git a/projects/petr/configs/petr_r50_8xb4-100e_coco.py b/projects/petr/configs/petr_r50_8xb4-100e_coco.py new file mode 100644 index 0000000000..36a1434344 --- /dev/null +++ b/projects/petr/configs/petr_r50_8xb4-100e_coco.py @@ -0,0 +1,264 @@ +_base_ = ['./_base_/default_runtime.py'] + +# learning policy +max_epochs = 100 +train_cfg = dict( + type='EpochBasedTrainLoop', max_epochs=max_epochs, val_interval=5) +val_cfg = dict(type='ValLoop') +test_cfg = dict(type='TestLoop') + +param_scheduler = [ + dict( + type='MultiStepLR', + begin=0, + end=max_epochs, + by_epoch=True, + milestones=[80], + gamma=0.1) +] + +# NOTE: `auto_scale_lr` is for automatically scaling LR, +# USER SHOULD NOT CHANGE ITS VALUES. +# base_batch_size = (16 GPUs) x (2 samples per GPU) +auto_scale_lr = dict(base_batch_size=32) + +# optimizer +optim_wrapper = dict( + type='OptimWrapper', + optimizer=dict(type='AdamW', lr=0.0002, weight_decay=0.0001), + clip_grad=dict(max_norm=0.1, norm_type=2), + paramwise_cfg=dict( + custom_keys={ + 'backbone': dict(lr_mult=0.1), + 'sampling_offsets': dict(lr_mult=0.1), + 'reference_points': dict(lr_mult=0.1) + })) + +# model +num_keypoints = 17 +checkpoint = 'https://download.openmmlab.com/mmdetection/v2.0/deformable_' \ + 'detr/deformable_detr_twostage_refine_r50_16x2_50e_coco/deformable_' \ + 'detr_twostage_refine_r50_16x2_50e_coco_20210419_220613-9d28ab72.pth' +model = dict( + type='PETR', + num_queries=300, + num_feature_levels=4, + num_keypoints=num_keypoints, + with_box_refine=True, + as_two_stage=True, + init_cfg=dict( + type='Pretrained', + checkpoint=checkpoint, + ), + data_preprocessor=dict( + type='DetDataPreprocessor', + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + bgr_to_rgb=True, + pad_size_divisor=1), + backbone=dict( + type='ResNet', + depth=50, + num_stages=4, + out_indices=(1, 2, 3), + frozen_stages=1, + norm_cfg=dict(type='BN', requires_grad=False), + norm_eval=True, + style='pytorch'), + neck=dict( + type='ChannelMapper', + in_channels=[512, 1024, 2048], + kernel_size=1, + out_channels=256, + act_cfg=None, + norm_cfg=dict(type='GN', num_groups=32), + num_outs=4), + encoder=dict( # DeformableDetrTransformerEncoder + num_layers=6, + layer_cfg=dict( # DeformableDetrTransformerEncoderLayer + self_attn_cfg=dict( # MultiScaleDeformableAttention + embed_dims=256, + batch_first=True), + ffn_cfg=dict( + embed_dims=256, feedforward_channels=1024, ffn_drop=0.1))), + decoder=dict( # PetrTransformerDecoder + num_layers=3, + num_keypoints=num_keypoints, + return_intermediate=True, + layer_cfg=dict( # PetrTransformerDecoderLayer + self_attn_cfg=dict( # MultiheadAttention + embed_dims=256, + num_heads=8, + dropout=0.1, + batch_first=True), + cross_attn_cfg=dict( # MultiScaleDeformablePoseAttention + embed_dims=256, + num_points=num_keypoints, + batch_first=True), + ffn_cfg=dict( + embed_dims=256, feedforward_channels=1024, ffn_drop=0.1)), + post_norm_cfg=None), + hm_encoder=dict( # DeformableDetrTransformerEncoder + num_layers=1, + layer_cfg=dict( # DeformableDetrTransformerEncoderLayer + self_attn_cfg=dict( # MultiScaleDeformableAttention + embed_dims=256, + num_levels=1, + batch_first=True), + ffn_cfg=dict( + embed_dims=256, feedforward_channels=1024, ffn_drop=0.1))), + kpt_decoder=dict( # DeformableDetrTransformerDecoder + num_layers=2, + return_intermediate=True, + layer_cfg=dict( # DetrTransformerDecoderLayer + self_attn_cfg=dict( # MultiheadAttention + embed_dims=256, + num_heads=8, + dropout=0.1, + batch_first=True), + cross_attn_cfg=dict( # MultiScaleDeformableAttention + embed_dims=256, + im2col_step=128), + ffn_cfg=dict( + embed_dims=256, feedforward_channels=1024, ffn_drop=0.1)), + post_norm_cfg=None), + positional_encoding=dict(num_feats=128, normalize=True, offset=-0.5), + bbox_head=dict( + type='PETRHead', + num_classes=1, + num_keypoints=num_keypoints, + sync_cls_avg_factor=True, + loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=2.0), + loss_reg=dict(type='L1Loss', loss_weight=80.0), + loss_reg_aux=dict(type='L1Loss', loss_weight=70.0), + loss_oks=dict( + type='OksLoss', + metainfo='configs/_base_/datasets/coco.py', + loss_weight=3.0), + loss_oks_aux=dict( + type='OksLoss', + metainfo='configs/_base_/datasets/coco.py', + loss_weight=2.0), + loss_hm=dict(type='mmpose.FocalHeatmapLoss', loss_weight=4.0), + ), + # training and testing settings + train_cfg=dict( + assigner=dict( + type='HungarianAssigner', + match_costs=[ + dict(type='FocalLossCost', weight=2.0), + dict(type='KptL1Cost', weight=70.0), + dict( + type='OksCost', + metainfo='configs/_base_/datasets/coco.py', + weight=7.0) + ])), + test_cfg=dict( + max_per_img=100, + score_thr=0.0, + )) + +train_pipeline = [ + dict(type='LoadImageFromFile', file_client_args=_base_.file_client_args), + dict(type='PoseToDetConverter'), + dict(type='PhotoMetricDistortion'), + dict( + type='RandomAffine', + max_rotate_degree=30.0, + # max_translate_ratio=0., + # scaling_ratio_range=(1., 1.), + # max_shear_degree=0., + scaling_ratio_range=(0.75, 1.0), + ), + dict(type='RandomFlip', prob=0.5), + dict( + type='RandomChoice', + transforms=[ + [ + dict( + type='RandomChoiceResize', + scales=list(zip(range(400, 1401, 8), (1400, ) * 126)), + keep_ratio=True) + ], + [ + dict( + type='RandomChoiceResize', + # The radio of all image in train dataset < 7 + # follow the original implement + scales=[(400, 4200), (500, 4200), (600, 4200)], + keep_ratio=True), + dict( + type='RandomCrop', + crop_type='absolute_range', + crop_size=(384, 600), + allow_negative_crop=True), + dict( + type='RandomChoiceResize', + scales=list(zip(range(400, 1401, 8), (1400, ) * 126)), + keep_ratio=True) + ] + ]), + dict(type='GenerateHeatmap'), + dict( + type='PackDetPoseInputs', + meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape')) +] + +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='PoseToDetConverter'), + dict(type='Resize', scale=(1333, 800), keep_ratio=True), + dict( + type='PackDetPoseInputs', + meta_keys=('id', 'img_id', 'img_path', 'ori_shape', 'img_shape', + 'scale_factor', 'flip_indices')) +] + +dataset_type = 'CocoDataset' +data_mode = 'bottomup' +data_root = 'data/coco/' + +train_dataloader = dict( + batch_size=2, + num_workers=2, + persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=True), + batch_sampler=dict(type='AspectRatioBatchSampler'), + dataset=dict( + type=dataset_type, + data_mode=data_mode, + data_root=data_root, + ann_file='annotations/person_keypoints_train2017.json', + data_prefix=dict(img='train2017/'), + filter_cfg=dict(filter_empty_gt=True, min_size=32), + pipeline=train_pipeline)) + +val_dataloader = dict( + batch_size=1, + num_workers=2, + persistent_workers=True, + drop_last=False, + sampler=dict(type='DefaultSampler', shuffle=False), + dataset=dict( + type=dataset_type, + data_mode=data_mode, + data_root=data_root, + ann_file='annotations/person_keypoints_val2017.json', + data_prefix=dict(img='val2017/'), + test_mode=True, + pipeline=test_pipeline)) +test_dataloader = val_dataloader + +val_evaluator = dict( + type='mmpose.CocoMetric', + ann_file=data_root + 'annotations/person_keypoints_val2017.json', + nms_mode='none', + score_mode='bbox') +test_evaluator = val_evaluator + +default_hooks = dict(checkpoint=dict(save_best='coco/AP', rule='greater')) diff --git a/projects/petr/configs/petr_swin-l_8xb4-100e_coco.py b/projects/petr/configs/petr_swin-l_8xb4-100e_coco.py new file mode 100644 index 0000000000..5bc8ad8f1e --- /dev/null +++ b/projects/petr/configs/petr_swin-l_8xb4-100e_coco.py @@ -0,0 +1,24 @@ +_base_ = ['petr_r50_8xb4-100e_coco.py'] + +# model +model = dict( + backbone=dict( + _delete_=True, + type='SwinTransformer', + embed_dims=192, + depths=[2, 2, 18, 2], + num_heads=[6, 12, 24, 48], + window_size=7, + mlp_ratio=4, + qkv_bias=True, + qk_scale=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0.5, + patch_norm=True, + out_indices=(1, 2, 3), + with_cp=False, + convert_weights=True), + neck=dict(in_channels=[384, 768, 1536])) + +optim_wrapper = dict(optimizer=dict(lr=0.0001)) diff --git a/projects/petr/datasets/__init__.py b/projects/petr/datasets/__init__.py new file mode 100644 index 0000000000..69bae9de53 --- /dev/null +++ b/projects/petr/datasets/__init__.py @@ -0,0 +1,3 @@ +from .bbox_keypoint_structure import * # noqa +from .coco_dataset import * # noqa +from .transforms import * # noqa diff --git a/projects/petr/datasets/bbox_keypoint_structure.py b/projects/petr/datasets/bbox_keypoint_structure.py new file mode 100644 index 0000000000..6b385f2f09 --- /dev/null +++ b/projects/petr/datasets/bbox_keypoint_structure.py @@ -0,0 +1,285 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from copy import deepcopy +from typing import List, Optional, Sequence, Tuple, Type, TypeVar, Union + +import numpy as np +import torch +from mmdet.structures.bbox import HorizontalBoxes +from torch import Tensor + +DeviceType = Union[str, torch.device] +T = TypeVar('T') +IndexType = Union[slice, int, list, torch.LongTensor, torch.cuda.LongTensor, + torch.BoolTensor, torch.cuda.BoolTensor, np.ndarray] + + +class BBoxKeypoints(HorizontalBoxes): + """The BBoxKeypoints class is a combination of bounding boxes and keypoints + representation. The box format used in BBoxKeypoints is the same as + HorizontalBoxes. + + Args: + data (Tensor or np.ndarray): The box data with shape of + (N, 4). + keypoints (Tensor or np.ndarray): The keypoint data with shape of + (N, K, 2). + keypoints_visible (Tensor or np.ndarray): The visibility of keypoints + with shape of (N, K). + dtype (torch.dtype, Optional): data type of boxes. Defaults to None. + device (str or torch.device, Optional): device of boxes. + Default to None. + clone (bool): Whether clone ``boxes`` or not. Defaults to True. + mode (str, Optional): the mode of boxes. If it is 'cxcywh', the + `data` will be converted to 'xyxy' mode. Defaults to None. + flip_indices (list, Optional): The indices of keypoints when the + images is flipped. Defaults to None. + + Notes: + N: the number of instances. + K: the number of keypoints. + """ + + def __init__(self, + data: Union[Tensor, np.ndarray], + keypoints: Union[Tensor, np.ndarray], + keypoints_visible: Union[Tensor, np.ndarray], + dtype: Optional[torch.dtype] = None, + device: Optional[DeviceType] = None, + clone: bool = True, + in_mode: Optional[str] = None, + flip_indices: Optional[List] = None) -> None: + + super().__init__( + data=data, + dtype=dtype, + device=device, + clone=clone, + in_mode=in_mode) + + assert len(data) == len(keypoints) + assert len(data) == len(keypoints_visible) + + assert keypoints.ndim == 3 + assert keypoints_visible.ndim == 2 + + keypoints = torch.as_tensor(keypoints) + keypoints_visible = torch.as_tensor(keypoints_visible) + + if device is not None: + keypoints = keypoints.to(device=device) + keypoints_visible = keypoints_visible.to(device=device) + + if clone: + keypoints = keypoints.clone() + keypoints_visible = keypoints_visible.clone() + + self.keypoints = keypoints + self.keypoints_visible = keypoints_visible + self.flip_indices = flip_indices + + def flip_(self, + img_shape: Tuple[int, int], + direction: str = 'horizontal') -> None: + """Flip boxes & kpts horizontally in-place. + + Args: + img_shape (Tuple[int, int]): A tuple of image height and width. + direction (str): Flip direction, options are "horizontal", + "vertical" and "diagonal". Defaults to "horizontal" + """ + assert direction == 'horizontal' + super().flip_(img_shape, direction) + self.keypoints[..., 0] = img_shape[1] - self.keypoints[..., 0] + self.keypoints = self.keypoints[:, self.flip_indices] + self.keypoints_visible = self.keypoints_visible[:, self.flip_indices] + + def translate_(self, distances: Tuple[float, float]) -> None: + """Translate boxes and keypoints in-place. + + Args: + distances (Tuple[float, float]): translate distances. The first + is horizontal distance and the second is vertical distance. + """ + boxes = self.tensor + assert len(distances) == 2 + self.tensor = boxes + boxes.new_tensor(distances).repeat(2) + distances = self.keypoints.new_tensor(distances).reshape(1, 1, 2) + self.keypoints = self.keypoints + distances + + def rescale_(self, scale_factor: Tuple[float, float]) -> None: + """Rescale boxes & keypoints w.r.t. rescale_factor in-place. + + Note: + Both ``rescale_`` and ``resize_`` will enlarge or shrink boxes + w.r.t ``scale_facotr``. The difference is that ``resize_`` only + changes the width and the height of boxes, but ``rescale_`` also + rescales the box centers simultaneously. + + Args: + scale_factor (Tuple[float, float]): factors for scaling boxes. + The length should be 2. + """ + boxes = self.tensor + assert len(scale_factor) == 2 + + self.tensor = boxes * boxes.new_tensor(scale_factor).repeat(2) + scale_factor = self.keypoints.new_tensor(scale_factor).reshape(1, 1, 2) + self.keypoints = self.keypoints * scale_factor + + def clip_(self, img_shape: Tuple[int, int]) -> None: + """Clip bounding boxes and set invisible keypoints outside the image + boundary in-place. + + Args: + img_shape (Tuple[int, int]): A tuple of image height and width. + """ + boxes = self.tensor + boxes[..., 0::2] = boxes[..., 0::2].clamp(0, img_shape[1]) + boxes[..., 1::2] = boxes[..., 1::2].clamp(0, img_shape[0]) + + kpt_outside = torch.logical_or( + torch.logical_or(self.keypoints[..., 0] < 0, + self.keypoints[..., 1] < 0), + torch.logical_or(self.keypoints[..., 0] > img_shape[1], + self.keypoints[..., 1] > img_shape[0])) + self.keypoints_visible[kpt_outside] *= 0 + + def project_(self, homography_matrix: Union[Tensor, np.ndarray]) -> None: + """Geometrically transform bounding boxes and keypoints in-place using + a homography matrix. + + Args: + homography_matrix (Tensor or np.ndarray): A 3x3 tensor or ndarray + representing the homography matrix for the transformation. + """ + boxes = self.tensor + if isinstance(homography_matrix, np.ndarray): + homography_matrix = boxes.new_tensor(homography_matrix) + + # Convert boxes to corners in homogeneous coordinates + corners = self.hbox2corner(boxes) + corners = torch.cat( + [corners, corners.new_ones(*corners.shape[:-1], 1)], dim=-1) + + # Convert keypoints to homogeneous coordinates + keypoints = torch.cat([ + self.keypoints, + self.keypoints.new_ones(*self.keypoints.shape[:-1], 1) + ], + dim=-1) + + # Transpose corners and keypoints for matrix multiplication + corners_T = torch.transpose(corners, -1, -2) + keypoints_T = torch.transpose(keypoints, -1, 0).contiguous().flatten(1) + + # Apply homography matrix to corners and keypoints + corners_T = torch.matmul(homography_matrix, corners_T) + keypoints_T = torch.matmul(homography_matrix, keypoints_T) + + # Transpose back to original shape + corners = torch.transpose(corners_T, -1, -2) + keypoints_T = keypoints_T.reshape(3, self.keypoints.shape[1], -1) + keypoints = torch.transpose(keypoints_T, -1, 0).contiguous() + + # Convert corners and keypoints back to non-homogeneous coordinates + corners = corners[..., :2] / corners[..., 2:3] + keypoints = keypoints[..., :2] / keypoints[..., 2:3] + + # Convert corners back to bounding boxes and update object attributes + self.tensor = self.corner2hbox(corners) + self.keypoints = keypoints + + @classmethod + def cat(cls: Type[T], box_list: Sequence[T], dim: int = 0) -> T: + """Cancatenates an instance list into one single instance. Similar to + ``torch.cat``. + + Args: + box_list (Sequence[T]): A sequence of instances. + dim (int): The dimension over which the box and keypoint are + concatenated. Defaults to 0. + + Returns: + T: Concatenated instance. + """ + assert isinstance(box_list, Sequence) + if len(box_list) == 0: + raise ValueError('box_list should not be a empty list.') + + assert dim == 0 + assert all(isinstance(boxes, cls) for boxes in box_list) + + th_box_list = torch.cat([boxes.tensor for boxes in box_list], dim=dim) + th_kpt_list = torch.cat([boxes.keypoints for boxes in box_list], + dim=dim) + th_kpt_vis_list = torch.cat( + [boxes.keypoints_visible for boxes in box_list], dim=dim) + flip_indices = box_list[0].flip_indices + return cls( + th_box_list, + th_kpt_list, + th_kpt_vis_list, + clone=False, + flip_indices=flip_indices) + + def __getitem__(self: T, index: IndexType) -> T: + """Rewrite getitem to protect the last dimension shape.""" + boxes = self.tensor + if isinstance(index, np.ndarray): + index = torch.as_tensor(index, device=self.device) + if isinstance(index, Tensor) and index.dtype == torch.bool: + assert index.dim() < boxes.dim() + elif isinstance(index, tuple): + assert len(index) < boxes.dim() + # `Ellipsis`(...) is commonly used in index like [None, ...]. + # When `Ellipsis` is in index, it must be the last item. + if Ellipsis in index: + assert index[-1] is Ellipsis + + boxes = boxes[index] + keypoints = self.keypoints[index] + keypoints_visible = self.keypoints_visible[index] + if boxes.dim() == 1: + boxes = boxes.reshape(1, -1) + keypoints = keypoints.reshape(1, -1, 2) + keypoints_visible = keypoints_visible.reshape(1, -1) + return type(self)( + boxes, + keypoints, + keypoints_visible, + flip_indices=self.flip_indices, + clone=False) + + @property + def num_keypoints(self) -> Tensor: + """Compute the number of visible keypoints for each object.""" + return self.keypoints_visible.sum(dim=1).int() + + def __deepcopy__(self, memo): + """Only clone the tensors when applying deepcopy.""" + cls = self.__class__ + other = cls.__new__(cls) + memo[id(self)] = other + other.tensor = self.tensor.clone() + other.keypoints = self.keypoints.clone() + other.keypoints_visible = self.keypoints_visible.clone() + other.flip_indices = deepcopy(self.flip_indices) + return other + + def clone(self: T) -> T: + """Reload ``clone`` for tensors.""" + return type(self)( + self.tensor, + self.keypoints, + self.keypoints_visible, + flip_indices=self.flip_indices, + clone=True) + + def to(self: T, *args, **kwargs) -> T: + """Reload ``to`` for tensors.""" + return type(self)( + self.tensor.to(*args, **kwargs), + self.keypoints.to(*args, **kwargs), + self.keypoints_visible.to(*args, **kwargs), + flip_indices=self.flip_indices, + clone=False) diff --git a/projects/petr/datasets/coco_dataset.py b/projects/petr/datasets/coco_dataset.py new file mode 100644 index 0000000000..99e0d1120f --- /dev/null +++ b/projects/petr/datasets/coco_dataset.py @@ -0,0 +1,79 @@ +import copy +from typing import Optional + +import numpy as np +from mmdet.registry import DATASETS + +from mmpose.datasets import CocoDataset as MMPoseCocoDataset + + +@DATASETS.register_module(force=True) +class CocoDataset(MMPoseCocoDataset): + + def parse_data_info(self, raw_data_info: dict) -> Optional[dict]: + """Parse raw COCO annotation of an instance. + + Args: + raw_data_info (dict): Raw data information loaded from + ``ann_file``. It should have following contents: + + - ``'raw_ann_info'``: Raw annotation of an instance + - ``'raw_img_info'``: Raw information of the image that + contains the instance + + Returns: + dict | None: Parsed instance annotation + """ + + ann = raw_data_info['raw_ann_info'] + img = raw_data_info['raw_img_info'] + + # filter invalid instance + if 'bbox' not in ann or 'keypoints' not in ann: + return None + + img_w, img_h = img['width'], img['height'] + + # get bbox in shape [1, 4], formatted as xywh + x, y, w, h = ann['bbox'] + x1 = np.clip(x, 0, img_w - 1) + y1 = np.clip(y, 0, img_h - 1) + x2 = np.clip(x + w, 0, img_w - 1) + y2 = np.clip(y + h, 0, img_h - 1) + + bbox = np.array([x1, y1, x2, y2], dtype=np.float32).reshape(1, 4) + + # keypoints in shape [1, K, 2] and keypoints_visible in [1, K] + _keypoints = np.array( + ann['keypoints'], dtype=np.float32).reshape(1, -1, 3) + keypoints = _keypoints[..., :2] + keypoints_visible = np.minimum(1, _keypoints[..., 2]) + + if 'num_keypoints' in ann: + num_keypoints = ann['num_keypoints'] + else: + num_keypoints = np.count_nonzero(keypoints.max(axis=2)) + + data_info = { + 'img_id': ann['image_id'], + 'img_path': img['img_path'], + 'width': img_w, + 'height': img_h, + 'bbox': bbox, + 'bbox_score': np.ones(1, dtype=np.float32), + 'num_keypoints': num_keypoints, + 'keypoints': keypoints, + 'keypoints_visible': keypoints_visible, + 'iscrowd': ann.get('iscrowd', 0), + 'segmentation': ann.get('segmentation', None), + 'id': ann['id'], + 'category_id': ann['category_id'], + # store the raw annotation of the instance + # it is useful for evaluation without providing ann_file + 'raw_ann_info': copy.deepcopy(ann), + } + + if 'crowdIndex' in img: + data_info['crowd_index'] = img['crowdIndex'] + + return data_info diff --git a/projects/petr/datasets/transforms.py b/projects/petr/datasets/transforms.py new file mode 100644 index 0000000000..82fd5b0f8c --- /dev/null +++ b/projects/petr/datasets/transforms.py @@ -0,0 +1,169 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Union + +import numpy as np +from mmcv.transforms import BaseTransform +from mmdet.datasets.transforms import PackDetInputs +from mmdet.registry import TRANSFORMS +from mmdet.structures.bbox.box_type import autocast_box_type +from mmengine.structures import PixelData + +from mmpose.codecs.utils import generate_gaussian_heatmaps +from .bbox_keypoint_structure import BBoxKeypoints + + +@TRANSFORMS.register_module(force=True) +class PoseToDetConverter(BaseTransform): + """This transform converts the pose data element into a format that is + suitable for the mmdet transforms.""" + + def transform(self, results: dict) -> dict: + + results['seg_map_path'] = None + results['height'] = results['img_shape'][0] + results['width'] = results['img_shape'][1] + + num_instances = len(results.get('bbox', [])) + + if num_instances == 0: + results['bbox'] = np.empty((0, 4), dtype=np.float32) + results['keypoints'] = np.empty( + (0, len(results['flip_indices']), 2), dtype=np.float32) + results['keypoints_visible'] = np.empty( + (0, len(results['flip_indices'])), dtype=np.int32) + results['category_id'] = [] + + results['gt_bboxes'] = BBoxKeypoints( + data=results['bbox'], + keypoints=results['keypoints'], + keypoints_visible=results['keypoints_visible'], + flip_indices=results['flip_indices'], + ) + + results['gt_ignore_flags'] = np.array([False] * num_instances) + results['gt_bboxes_labels'] = np.array(results['category_id']) - 1 + + return results + + +@TRANSFORMS.register_module(force=True) +class PackDetPoseInputs(PackDetInputs): + mapping_table = { + 'gt_bboxes': 'bboxes', + 'gt_bboxes_labels': 'labels', + 'gt_masks': 'masks', + 'gt_keypoints': 'keypoints', + 'gt_keypoints_visible': 'keypoints_visible' + } + field_mapping_table = { + 'gt_heatmaps': 'gt_heatmaps', + } + + def __init__(self, + meta_keys=('id', 'img_id', 'img_path', 'ori_shape', + 'img_shape', 'scale_factor', 'flip', + 'flip_direction', 'flip_indices', 'raw_ann_info'), + pack_transformed=False): + self.meta_keys = meta_keys + + def transform(self, results: dict) -> dict: + results['gt_keypoints'] = results['gt_bboxes'].keypoints + results['gt_keypoints_visible'] = results[ + 'gt_bboxes'].keypoints_visible + + # pack fields + gt_fields = None + for key, packed_key in self.field_mapping_table.items(): + if key in results: + + if gt_fields is None: + gt_fields = PixelData() + else: + assert isinstance( + gt_fields, PixelData + ), 'Got mixed single-level and multi-level pixel data.' + + gt_fields.set_field(results[key], packed_key) + + results = super().transform(results) + if gt_fields: + results['data_samples'].gt_fields = gt_fields.to_tensor() + + return results + + +@TRANSFORMS.register_module(force=True) +class GenerateHeatmap(BaseTransform): + + def _get_instance_wise_sigmas(self, + bbox: np.ndarray, + heatmap_min_overlap: float = 0.9 + ) -> np.ndarray: + """Get sigma values for each instance according to their size. + + Args: + bbox (np.ndarray): Bounding box in shape (N, 4, 2) + + Returns: + np.ndarray: Array containing the sigma values for each instance. + """ + sigmas = np.zeros((bbox.shape[0], ), dtype=np.float32) + + heights = bbox[:, 3] - bbox[:, 1] + widths = bbox[:, 2] - bbox[:, 0] + + for i in range(bbox.shape[0]): + h, w = heights[i], widths[i] + + # compute sigma for each instance + # condition 1 + a1, b1 = 1, h + w + c1 = w * h * (1 - heatmap_min_overlap) / (1 + heatmap_min_overlap) + sq1 = np.sqrt(b1**2 - 4 * a1 * c1) + r1 = (b1 + sq1) / 2 + + # condition 2 + a2 = 4 + b2 = 2 * (h + w) + c2 = (1 - heatmap_min_overlap) * w * h + sq2 = np.sqrt(b2**2 - 4 * a2 * c2) + r2 = (b2 + sq2) / 2 + + # condition 3 + a3 = 4 * heatmap_min_overlap + b3 = -2 * heatmap_min_overlap * (h + w) + c3 = (heatmap_min_overlap - 1) * w * h + sq3 = np.sqrt(b3**2 - 4 * a3 * c3) + r3 = (b3 + sq3) / 2 + + sigmas[i] = min(r1, r2, r3) / 3 + + return sigmas + + @autocast_box_type() + def transform(self, results: dict) -> Union[dict, None]: + """Transform function to filter annotations. + + Args: + results (dict): Result dict. + + Returns: + dict: Updated result dict. + """ + assert 'gt_bboxes' in results + + bbox = results['gt_bboxes'].tensor.numpy() / 8 + keypoints = results['gt_bboxes'].keypoints.numpy() / 8 + keypoints_visible = results['gt_bboxes'].keypoints_visible.numpy() + + heatmap_size = [ + results['img_shape'][1] // 8 + 1, results['img_shape'][0] // 8 + 1 + ] + sigmas = self._get_instance_wise_sigmas(bbox) + + hm, _ = generate_gaussian_heatmaps(heatmap_size, keypoints, + keypoints_visible, sigmas) + + results['gt_heatmaps'] = hm + + return results diff --git a/projects/petr/models/__init__.py b/projects/petr/models/__init__.py new file mode 100644 index 0000000000..e9b7e5f0c1 --- /dev/null +++ b/projects/petr/models/__init__.py @@ -0,0 +1,4 @@ +from .losses import * # noqa +from .match_costs import * # noqa +from .petr import * # noqa +from .petr_head import * # noqa diff --git a/projects/petr/models/losses.py b/projects/petr/models/losses.py new file mode 100644 index 0000000000..8e1601e51f --- /dev/null +++ b/projects/petr/models/losses.py @@ -0,0 +1,89 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Optional + +import torch +import torch.nn as nn +from mmdet.registry import MODELS +from torch import Tensor + +from mmpose.datasets.datasets.utils import parse_pose_metainfo + + +@MODELS.register_module(force=True) +class OksLoss(nn.Module): + """A PyTorch implementation of the Object Keypoint Similarity (OKS) loss as + described in the paper "YOLO-Pose: Enhancing YOLO for Multi Person Pose + Estimation Using Object Keypoint Similarity Loss" by Debapriya et al. + (2022). + + The OKS loss is used for keypoint-based object recognition and consists + of a measure of the similarity between predicted and ground truth + keypoint locations, adjusted by the size of the object in the image. + + The loss function takes as input the predicted keypoint locations, the + ground truth keypoint locations, a mask indicating which keypoints are + valid, and bounding boxes for the objects. + + Args: + metainfo (Optional[str]): Path to a JSON file containing information + about the dataset's annotations. + loss_weight (float): Weight for the loss. + """ + + def __init__(self, + metainfo: Optional[str] = None, + loss_weight: float = 1.0): + super().__init__() + + if metainfo is not None: + metainfo = parse_pose_metainfo(dict(from_file=metainfo)) + sigmas = metainfo.get('sigmas', None) + if sigmas is not None: + self.register_buffer('sigmas', torch.as_tensor(sigmas)) + self.loss_weight = loss_weight + + def forward(self, + output: Tensor, + target: Tensor, + target_weights: Tensor, + bboxes: Optional[Tensor] = None) -> Tensor: + oks = self.compute_oks(output, target, target_weights, bboxes) + loss = 1 - oks + return loss.mean() * self.loss_weight + + def compute_oks(self, + output: Tensor, + target: Tensor, + target_weights: Tensor, + bboxes: Optional[Tensor] = None) -> Tensor: + """Calculates the OKS loss. + + Args: + output (Tensor): Predicted keypoints in shape N x k x 2, where N + is batch size, k is the number of keypoints, and 2 are the + xy coordinates. + target (Tensor): Ground truth keypoints in the same shape as + output. + target_weights (Tensor): Mask of valid keypoints in shape N x k, + with 1 for valid and 0 for invalid. + bboxes (Optional[Tensor]): Bounding boxes in shape N x 4, + where 4 are the xyxy coordinates. + + Returns: + Tensor: The calculated OKS loss. + """ + + dist = torch.norm(output - target, dim=-1) + + if hasattr(self, 'sigmas'): + sigmas = self.sigmas.reshape(*((1, ) * (dist.ndim - 1)), -1) + if sigmas.device != dist.device: + sigmas = sigmas.to(dist.device) + dist = dist / sigmas + if bboxes is not None: + area = torch.prod( + bboxes[..., 2:] - bboxes[..., :2], dim=-1).pow(0.5) + dist = dist / area.clip(min=1e-8).unsqueeze(-1) + + return (torch.exp(-dist.pow(2) / 2) * target_weights).sum( + dim=-1) / target_weights.sum(dim=-1).clip(min=1e-8) diff --git a/projects/petr/models/match_costs.py b/projects/petr/models/match_costs.py new file mode 100644 index 0000000000..0930ef9a83 --- /dev/null +++ b/projects/petr/models/match_costs.py @@ -0,0 +1,63 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Optional + +import torch +from mmdet.models.task_modules.assigners.match_cost import BaseMatchCost +from mmdet.registry import TASK_UTILS +from mmengine.structures import InstanceData +from torch import Tensor + +from .losses import OksLoss + + +@TASK_UTILS.register_module() +class KptL1Cost(BaseMatchCost): + + def __call__(self, + pred_instances: InstanceData, + gt_instances: InstanceData, + img_meta: Optional[dict] = None, + **kwargs) -> Tensor: + + pred_keypoints = pred_instances.keypoints + gt_keypoints = gt_instances.keypoints + + # normalized + img_h, img_w = img_meta['img_shape'] + factor = gt_keypoints.new_tensor([img_w, img_h]).reshape(1, 1, 2) + gt_keypoints = (gt_keypoints / factor).flatten(1) + pred_keypoints = (pred_keypoints / factor).flatten(1) + + kpt_cost = torch.cdist(pred_keypoints, gt_keypoints, p=1) + return kpt_cost * self.weight + + +@TASK_UTILS.register_module() +class OksCost(BaseMatchCost, OksLoss): + + def __init__(self, metainfo: Optional[str] = None, weight: float = 1.0): + OksLoss.__init__(self, metainfo, weight) + self.weight = self.loss_weight + + def __call__(self, + pred_instances: InstanceData, + gt_instances: InstanceData, + img_meta: Optional[dict] = None, + **kwargs) -> Tensor: + + pred_keypoints = pred_instances.keypoints + gt_keypoints = gt_instances.keypoints + gt_bboxes = gt_instances.bboxes + gt_keypoints_visible = gt_instances.keypoints_visible + + # normalized + img_h, img_w = img_meta['img_shape'] + factor = gt_keypoints.new_tensor([img_w, img_h]).reshape(1, 1, 2) + gt_keypoints = (gt_keypoints / factor).unsqueeze(0) + pred_keypoints = (pred_keypoints / factor).unsqueeze(1) + gt_bboxes = (gt_bboxes.reshape(-1, 2, 2) / factor).reshape(1, -1, 4) + + kpt_cost = self.compute_oks(pred_keypoints, gt_keypoints, + gt_keypoints_visible, gt_bboxes) + kpt_cost = -kpt_cost + return kpt_cost * self.weight diff --git a/projects/petr/models/petr.py b/projects/petr/models/petr.py new file mode 100644 index 0000000000..11de5a0d57 --- /dev/null +++ b/projects/petr/models/petr.py @@ -0,0 +1,716 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from collections import OrderedDict +from typing import Dict, Tuple, Union + +import torch +import torch.nn.functional as F +from mmcv.cnn.bricks.transformer import MultiScaleDeformableAttention +from mmdet.models.detectors import DeformableDETR +from mmdet.models.layers import (DeformableDetrTransformerDecoder, + DeformableDetrTransformerEncoder, + SinePositionalEncoding) +from mmdet.models.layers.transformer.utils import inverse_sigmoid +from mmdet.registry import MODELS +from mmdet.structures import OptSampleList, SampleList +from mmengine.model import xavier_init +from torch import Tensor, nn +from torch.nn.init import normal_ + +from .transformers import PetrTransformerDecoder + + +@MODELS.register_module() +class PETR(DeformableDETR): + + def __init__(self, + num_keypoints: int = 17, + hm_encoder: dict = None, + kpt_decoder: dict = None, + *args, + **kwargs): + self.num_keypoints = num_keypoints + self.hm_encoder = hm_encoder + self.kpt_decoder = kpt_decoder + super().__init__(*args, **kwargs) + + self._register_load_state_dict_pre_hook(self._load_state_dict_pre_hook) + self.encoder._register_load_state_dict_pre_hook( + self._load_state_dict_pre_hook) + self.decoder._register_load_state_dict_pre_hook( + self._load_state_dict_pre_hook) + + def _init_layers(self) -> None: + """Initialize layers except for backbone, neck and bbox_head.""" + self.positional_encoding = SinePositionalEncoding( + **self.positional_encoding) + self.encoder = DeformableDetrTransformerEncoder(**self.encoder) + self.decoder = PetrTransformerDecoder(**self.decoder) + self.hm_encoder = DeformableDetrTransformerEncoder(**self.hm_encoder) + self.kpt_decoder = DeformableDetrTransformerDecoder(**self.kpt_decoder) + self.embed_dims = self.encoder.embed_dims + self.query_embedding = nn.Embedding(self.num_queries, + self.embed_dims * 2) + self.kpt_query_embedding = nn.Embedding(self.num_keypoints, + self.embed_dims * 2) + + num_feats = self.positional_encoding.num_feats + assert num_feats * 2 == self.embed_dims, \ + 'embed_dims should be exactly 2 times of num_feats. ' \ + f'Found {self.embed_dims} and {num_feats}.' + + self.level_embed = nn.Parameter( + torch.Tensor(self.num_feature_levels, self.embed_dims)) + + if self.as_two_stage: + self.memory_trans_fc = nn.Linear(self.embed_dims, self.embed_dims) + self.memory_trans_norm = nn.LayerNorm(self.embed_dims) + else: + self.reference_points_fc = nn.Linear(self.embed_dims, 2) + + def init_weights(self) -> None: + """Initialize weights for Transformer and other components.""" + super(DeformableDETR, self).init_weights() + for coder in self.encoder, self.decoder: + for p in coder.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + for m in self.modules(): + if isinstance(m, MultiScaleDeformableAttention): + m.init_weights() + if self.as_two_stage: + nn.init.xavier_uniform_(self.memory_trans_fc.weight) + else: + xavier_init( + self.reference_points_fc, distribution='uniform', bias=0.) + normal_(self.level_embed) + + def forward_transformer(self, + img_feats: Tuple[Tensor], + batch_data_samples: OptSampleList = None, + test_mode: bool = True) -> Dict: + """Forward process of Transformer, which includes four steps: + 'pre_transformer' -> 'encoder' -> 'pre_decoder' -> 'decoder'. We + summarized the parameters flow of the existing DETR-like detector, + which can be illustrated as follow: + + .. code:: text + + img_feats & batch_data_samples + | + V + +-----------------+ + | pre_transformer | + +-----------------+ + | | + | V + | +-----------------+ + | | forward_encoder | + | +-----------------+ + | | + | V + | +---------------+ + | | pre_decoder | + | +---------------+ + | | | + V V | + +-----------------+ | + | forward_decoder | | + +-----------------+ | + | | + V V + head_inputs_dict + + Args: + img_feats (tuple[Tensor]): Tuple of feature maps from neck. Each + feature map has shape (bs, dim, H, W). + batch_data_samples (list[:obj:`DetDataSample`], optional): The + batch data samples. It usually includes information such + as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`. + Defaults to None. + + Returns: + dict: The dictionary of bbox_head function inputs, which always + includes the `hidden_states` of the decoder output and may contain + `references` including the initial and intermediate references. + """ + encoder_inputs_dict, decoder_inputs_dict = self.pre_transformer( + img_feats, batch_data_samples) + # encoder_inputs_dict + # |- feat: [bs, mlv_shape, 256] + # |- feat_mask: [bs, mlv_shape] + # |- feat_pos: [bs, mlv_shape, 256] + # |- spatial_shapes: [4, 2] + # |- level_start_index: [4] + # |- valid_ratios: [bs, 4, 2] + # decoder_inputs_dict + # |- memory_mask: [bs, mlv_shape] + # |- spatial_shapes [4, 2] + # |- level_start_index [4] + # |- valid_ratios [bs, 4, 2] + + encoder_outputs_dict, heatmap_dict = self.forward_encoder( + **encoder_inputs_dict, test_mode=test_mode) + # encoder_outputs_dict + # |- memory: [bs, mlv_shape, 256] + # |- memory_mask: [bs, mlv_shape] (feat_mask) + # |- spatial_shapes: [4, 2] + # heatmap_dict + # |- hm_memory: [bs, lv0_h, lv0_w, 256] + # |- hm_mask: [bs, lv0_h, lv0_w] + + tmp_dec_in, head_inputs_dict = self.pre_decoder(**encoder_outputs_dict) + # tmp_dec_in + # |- query: [bs, num_queries, 256] + # |- query_pos: [bs, num_queries, 256] + # |- memory: [bs, mlv_shape, 256] (memory) + # |- reference_points: [bs, num_queries, 2*num_keypoints] + # head_inputs_dict (train only) + # |- enc_outputs_class: [bs, mlv_shape, 1] + # |- enc_outputs_coord: [bs, mlv_shape, 34] + + decoder_inputs_dict.update(tmp_dec_in) + # decoder_inputs_dict + # |- query: [bs, num_queries, 256] + # |- query_pos: [bs, num_queries, 256] + # |- memory: [bs, mlv_shape, 256] (memory) + # |- memory_mask: [bs, mlv_shape] + # |- spatial_shapes [4, 2] + # |- level_start_index [4] + # |- valid_ratios [bs, 4, 2] + + decoder_outputs_dict = self.forward_decoder(**decoder_inputs_dict) + # decoder_outputs_dict + # |- hidden_states [3, bs, num_queries, 256] + # |- references [1, 300, 34] * 4 + # |- all_layers_classes [3, bs, num_queries, 1] + # |- all_layers_coords [3, bs, num_queries, 2*num_keypoints] + + kpt_decoder_inputs_dict = self.pre_kpt_decoder( + **decoder_outputs_dict, + batch_data_samples=batch_data_samples, + test_mode=test_mode) + # kpt_decoder_inputs_dict + # |- pos_kpt_coords: [max_inst, 2*num_keypoints] + # |- pos_img_inds: [max_inst] + # |- det_labels: [max_inst] (test only) + # |- det_scores: [max_inst] (test only) + + kpt_decoder_inputs_dict.update(decoder_inputs_dict) + # kpt_decoder_inputs_dict + # |- pos_kpt_coords: [max_inst, 2*num_keypoints] + # |- pos_img_inds: [max_inst] + # |- det_labels: [max_inst] + # |- query: [bs, num_queries, 256] + # |- query_pos: [bs, num_queries, 256] + # |- memory: [bs, mlv_shape, 256] (memory) + # |- memory_mask: [bs, mlv_shape] + # |- spatial_shapes [4, 2] + # |- level_start_index [4] + # |- valid_ratios [bs, 4, 2] + + kpt_decoder_outputs_dict = self.forward_kpt_decoder( + **kpt_decoder_inputs_dict) + # kpt_decoder_outputs_dict (test) + # |- inter_states: [2, max_inst, num_keypoints, 256] + # |- reference_points: [max_inst, num_keypoints, 2] + # |- inter_references: [2, max_inst, num_keypoints, 2] + + dec_outputs_coord = self.forward_kpt_head(**kpt_decoder_outputs_dict) + # dec_outputs_coord: [2, max_inst, num_keypoints, 2] + + head_inputs_dict['dec_outputs_coord'] = dec_outputs_coord + if test_mode: + head_inputs_dict['det_labels'] = kpt_decoder_inputs_dict[ + 'det_labels'] + head_inputs_dict['det_scores'] = kpt_decoder_inputs_dict[ + 'det_scores'] + else: + head_inputs_dict.update(heatmap_dict) + head_inputs_dict['all_layers_classes'] = decoder_outputs_dict[ + 'all_layers_classes'] + head_inputs_dict['all_layers_coords'] = decoder_outputs_dict[ + 'all_layers_coords'] + # head_inputs_dict + # |- enc_outputs_class: [bs, mlv_shape, 1] (train only) + # |- enc_outputs_coord: [bs, mlv_shape, 34] (train only) + # |- all_layers_classes [3, bs, num_queries, 1] (train only) + # |- all_layers_coords [3, bs, num_queries, 2*num_keypoints] (to) + # |- hm_memory: [bs, lv0_h, lv0_w, 256] (train only) + # |- hm_mask: [bs, lv0_h, lv0_w] (train only) + # |- dec_outputs_coord: [2, max_inst, num_keypoints, 2] + # |- det_labels: [max_inst] (test only) + # |- det_scores: [max_inst] (test only) + + return head_inputs_dict + + def loss(self, batch_inputs: Tensor, + batch_data_samples: SampleList) -> Union[dict, list]: + """Calculate losses from a batch of inputs and data samples. + + Args: + batch_inputs (Tensor): Input images of shape (bs, dim, H, W). + These should usually be mean centered and std scaled. + batch_data_samples (List[:obj:`DetDataSample`]): The batch + data samples. It usually includes information such + as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`. + + Returns: + dict: A dictionary of loss components + """ + # torch.save(dict( + # batch_inputs=batch_inputs.cpu(), + # batch_data_samples=batch_data_samples + # ), 'notebooks/train_proc_tensors/img+ds.pth') + # exit(0) + + img_feats = self.extract_feat(batch_inputs) + head_inputs_dict = self.forward_transformer( + img_feats, batch_data_samples, test_mode=False) + # head_inputs_dict + # |- enc_outputs_class: [bs, mlv_shape, 1] + # |- enc_outputs_coord: [bs, mlv_shape, 34] + # |- all_layers_classes [3, bs, num_queries, 1] + # |- all_layers_coords [3, bs, num_queries, 2*num_keypoints] + # |- hm_memory: [bs, lv0_h, lv0_w, 256] + # |- hm_mask: [bs, lv0_h, lv0_w] + # |- dec_outputs_coord: [2, max_inst, num_keypoints, 2] + + losses = self.bbox_head.loss( + **head_inputs_dict, batch_data_samples=batch_data_samples) + + return losses + + def predict(self, + batch_inputs: Tensor, + batch_data_samples: SampleList, + rescale: bool = True) -> SampleList: + """Predict results from a batch of inputs and data samples with post- + processing. + + Args: + batch_inputs (Tensor): Inputs, has shape (bs, dim, H, W). + batch_data_samples (List[:obj:`DetDataSample`]): The batch + data samples. It usually includes information such + as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`. + rescale (bool): Whether to rescale the results. + Defaults to True. + + Returns: + list[:obj:`DetDataSample`]: Detection results of the input images. + Each DetDataSample usually contain 'pred_instances'. And the + `pred_instances` usually contains following keys. + + - scores (Tensor): Classification scores, has a shape + (num_instance, ) + - labels (Tensor): Labels of bboxes, has a shape + (num_instances, ). + - bboxes (Tensor): Has a shape (num_instances, 4), + the last dimension 4 arrange as (x1, y1, x2, y2). + """ + + img_feats = self.extract_feat(batch_inputs) + head_inputs_dict = self.forward_transformer( + img_feats, batch_data_samples, test_mode=True) + # head_inputs_dict + # |- dec_outputs_coord: [2, max_inst, num_keypoints, 2] + # |- det_labels: [max_inst] + # |- det_scores: [max_inst] + + results_list = self.bbox_head.predict( + **head_inputs_dict, + rescale=rescale, + batch_data_samples=batch_data_samples) + batch_data_samples = self.add_pred_to_datasample( + batch_data_samples, results_list) + return batch_data_samples + + def forward_encoder(self, + feat: Tensor, + feat_mask: Tensor, + feat_pos: Tensor, + spatial_shapes: Tensor, + level_start_index: Tensor, + valid_ratios: Tensor, + test_mode: bool = True) -> Dict: + """Forward with Transformer encoder. + + The forward procedure of the transformer is defined as: + 'pre_transformer' -> 'encoder' -> 'pre_decoder' -> 'decoder' + More details can be found at `TransformerDetector.forward_transformer` + in `mmdet/detector/base_detr.py`. + + Args: + feat (Tensor): Sequential features, has shape (bs, num_feat_points, + dim). + feat_mask (Tensor): ByteTensor, the padding mask of the features, + has shape (bs, num_feat_points). + feat_pos (Tensor): The positional embeddings of the features, has + shape (bs, num_feat_points, dim). + spatial_shapes (Tensor): Spatial shapes of features in all levels, + has shape (num_levels, 2), last dimension represents (h, w). + level_start_index (Tensor): The start index of each level. + A tensor has shape (num_levels, ) and can be represented + as [0, h_0*w_0, h_0*w_0+h_1*w_1, ...]. + valid_ratios (Tensor): The ratios of the valid width and the valid + height relative to the width and the height of features in all + levels, has shape (bs, num_levels, 2). + + Returns: + dict: The dictionary of encoder outputs, which includes the + `memory` of the encoder output. + """ + memory = self.encoder( + query=feat, + query_pos=feat_pos, + key_padding_mask=feat_mask, # for self_attn + spatial_shapes=spatial_shapes, + level_start_index=level_start_index, + valid_ratios=valid_ratios) + + encoder_outputs_dict = dict( + memory=memory, + memory_mask=feat_mask, + spatial_shapes=spatial_shapes) + + # only for training + heatmap_dict = dict() + if not test_mode: + batch_size = memory.size(0) + hm_memory = memory[:, level_start_index[0]:level_start_index[1], :] + hm_pos_embed = feat_pos[:, level_start_index[0]: + level_start_index[1], :] + hm_mask = feat_mask[:, level_start_index[0]:level_start_index[1]] + hm_memory = self.hm_encoder( + query=hm_memory, + query_pos=hm_pos_embed, + query_pos=None, + key_padding_mask=hm_mask, + spatial_shapes=spatial_shapes.narrow(0, 0, 1), + level_start_index=level_start_index[0], + valid_ratios=valid_ratios.narrow(1, 0, 1)) + hm_memory = hm_memory.reshape(batch_size, spatial_shapes[0, 0], + spatial_shapes[0, 1], -1) + hm_mask = hm_mask.reshape(batch_size, spatial_shapes[0, 0], + spatial_shapes[0, 1]) + + heatmap_dict['hm_memory'] = hm_memory + heatmap_dict['hm_mask'] = hm_mask + + return encoder_outputs_dict, heatmap_dict + + def pre_decoder(self, memory: Tensor, memory_mask: Tensor, + spatial_shapes: Tensor) -> Tuple[Dict, Dict]: + """Prepare intermediate variables before entering Transformer decoder, + such as `query`, `query_pos`, and `reference_points`. + + The forward procedure of the transformer is defined as: + 'pre_transformer' -> 'encoder' -> 'pre_decoder' -> 'decoder' + More details can be found at `TransformerDetector.forward_transformer` + in `mmdet/detector/base_detr.py`. + + Args: + memory (Tensor): The output embeddings of the Transformer encoder, + has shape (bs, num_feat_points, dim). + memory_mask (Tensor): ByteTensor, the padding mask of the memory, + has shape (bs, num_feat_points). It will only be used when + `as_two_stage` is `True`. + spatial_shapes (Tensor): Spatial shapes of features in all levels, + has shape (num_levels, 2), last dimension represents (h, w). + It will only be used when `as_two_stage` is `True`. + + Returns: + tuple[dict, dict]: The decoder_inputs_dict and head_inputs_dict. + + - decoder_inputs_dict (dict): The keyword dictionary args of + `self.forward_decoder()`, which includes 'query', 'query_pos', + 'memory', and `reference_points`. The reference_points of + decoder input here are 4D boxes when `as_two_stage` is `True`, + otherwise 2D points, although it has `points` in its name. + The reference_points in encoder is always 2D points. + - head_inputs_dict (dict): The keyword dictionary args of the + bbox_head functions, which includes `enc_outputs_class` and + `enc_outputs_coord`. They are both `None` when 'as_two_stage' + is `False`. The dict is empty when `self.training` is `False`. + """ + batch_size, _, c = memory.shape + + query_embed = self.query_embedding.weight + query_pos, query = torch.split(query_embed, c, dim=1) + query_pos = query_pos.unsqueeze(0).expand(batch_size, -1, -1) + query = query.unsqueeze(0).expand(batch_size, -1, -1) + + if self.as_two_stage: + output_memory, output_proposals = \ + self.gen_encoder_output_proposals( + memory, memory_mask, spatial_shapes) + enc_outputs_class = self.bbox_head.cls_branches[ + self.decoder.num_layers]( + output_memory) + enc_outputs_coord_unact = self.bbox_head.reg_branches[ + self.decoder.num_layers]( + output_memory) + enc_outputs_coord_unact[..., 0::2] += output_proposals[..., 0:1] + enc_outputs_coord_unact[..., 1::2] += output_proposals[..., 1:2] + enc_outputs_coord = enc_outputs_coord_unact.sigmoid() + # We only use the first channel in enc_outputs_class as foreground, + # the other (num_classes - 1) channels are actually not used. + # Its targets are set to be 0s, which indicates the first + # class (foreground) because we use [0, num_classes - 1] to + # indicate class labels, background class is indicated by + # num_classes (similar convention in RPN). + # See https://github.com/open-mmlab/mmdetection/blob/master/mmdet/models/dense_heads/deformable_detr_head.py#L241 # noqa + # This follows the official implementation of Deformable DETR. + topk_proposals = torch.topk( + enc_outputs_class[..., 0], self.num_queries, dim=1)[1] + topk_coords_unact = torch.gather( + enc_outputs_coord_unact, 1, + topk_proposals.unsqueeze(-1).repeat( + 1, 1, enc_outputs_coord_unact.size(-1))) + topk_coords_unact = topk_coords_unact.detach() + reference_points = topk_coords_unact.sigmoid() + else: + enc_outputs_class, enc_outputs_coord = None, None + reference_points = self.reference_points_fc(query_pos).sigmoid() + + decoder_inputs_dict = dict( + query=query, + query_pos=query_pos, + memory=memory, + reference_points=reference_points) + head_inputs_dict = dict( + enc_outputs_class=enc_outputs_class, + enc_outputs_coord=enc_outputs_coord) if self.training else dict() + return decoder_inputs_dict, head_inputs_dict + + def forward_decoder(self, query: Tensor, query_pos: Tensor, memory: Tensor, + memory_mask: Tensor, reference_points: Tensor, + spatial_shapes: Tensor, level_start_index: Tensor, + valid_ratios: Tensor) -> Dict: + """Forward with Transformer decoder. + + The forward procedure of the transformer is defined as: + 'pre_transformer' -> 'encoder' -> 'pre_decoder' -> 'decoder' + More details can be found at `TransformerDetector.forward_transformer` + in `mmdet/detector/base_detr.py`. + + Args: + query (Tensor): The queries of decoder inputs, has shape + (bs, num_queries, dim). + query_pos (Tensor): The positional queries of decoder inputs, + has shape (bs, num_queries, dim). + memory (Tensor): The output embeddings of the Transformer encoder, + has shape (bs, num_feat_points, dim). + memory_mask (Tensor): ByteTensor, the padding mask of the memory, + has shape (bs, num_feat_points). + reference_points (Tensor): The initial reference, has shape + (bs, num_queries, 4) with the last dimension arranged as + (cx, cy, w, h) when `as_two_stage` is `True`, otherwise has + shape (bs, num_queries, 2) with the last dimension arranged as + (cx, cy). + spatial_shapes (Tensor): Spatial shapes of features in all levels, + has shape (num_levels, 2), last dimension represents (h, w). + level_start_index (Tensor): The start index of each level. + A tensor has shape (num_levels, ) and can be represented + as [0, h_0*w_0, h_0*w_0+h_1*w_1, ...]. + valid_ratios (Tensor): The ratios of the valid width and the valid + height relative to the width and the height of features in all + levels, has shape (bs, num_levels, 2). + + Returns: + dict: The dictionary of decoder outputs, which includes the + `hidden_states` of the decoder output and `references` including + the initial and intermediate reference_points. + """ + inter_states, inter_references = self.decoder( + query=query, + value=memory, + query_pos=query_pos, + key_padding_mask=memory_mask, # for cross_attn + reference_points=reference_points, + spatial_shapes=spatial_shapes, + level_start_index=level_start_index, + valid_ratios=valid_ratios, + reg_branches=self.bbox_head.reg_branches + if self.with_box_refine else None) + references = [reference_points, *inter_references] + + all_layers_classes, all_layers_coords = self.bbox_head( + hidden_states=inter_states, references=references) + + decoder_outputs_dict = dict( + hidden_states=inter_states, + references=references, + all_layers_classes=all_layers_classes, + all_layers_coords=all_layers_coords) + + return decoder_outputs_dict + + def pre_kpt_decoder(self, + all_layers_classes, + all_layers_coords, + batch_data_samples, + test_mode=False, + **kwargs): + + cls_scores = all_layers_classes[-1] + kpt_coords = all_layers_coords[-1] + + if test_mode: + assert cls_scores.size(0) == 1, \ + f'only `batch_size=1` is supported in testing, but got ' \ + f'{cls_scores.size(0)}' + + cls_scores = cls_scores[0] + kpt_coords = kpt_coords[0] + + max_per_img = self.test_cfg['max_per_img'] + if self.bbox_head.loss_cls.use_sigmoid: + cls_scores = cls_scores.sigmoid() + scores, indices = cls_scores.view(-1).topk(max_per_img) + det_labels = indices % self.bbox_head.num_classes + bbox_index = indices // self.bbox_head.num_classes + kpt_coords = kpt_coords[bbox_index] + else: + scores, det_labels = F.softmax( + cls_scores, dim=-1)[..., :-1].max(-1) + scores, bbox_index = scores.topk(max_per_img) + kpt_coords = kpt_coords[bbox_index] + det_labels = det_labels[bbox_index] + + kpt_weights = torch.ones_like(kpt_coords) + + kpt_decoder_inputs_dict = dict( + det_labels=det_labels, + det_scores=scores, + ) + + else: + + batch_gt_instances = [ds.gt_instances for ds in batch_data_samples] + batch_img_metas = [ds.metainfo for ds in batch_data_samples] + + num_imgs = cls_scores.size(0) + cls_scores_list = [cls_scores[i] for i in range(num_imgs)] + kpt_coords_list = [ + kpt_coords[i].reshape(-1, self.bbox_head.num_keypoints, 2) + for i in range(num_imgs) + ] + + cls_reg_targets = self.bbox_head.get_targets( + cls_scores_list, + kpt_coords_list, + batch_gt_instances, + batch_img_metas, + cache_targets=True) + kpt_weights = torch.cat(cls_reg_targets[4]) + kpt_coords = kpt_coords.flatten(0, 1) + kpt_decoder_inputs_dict = {} + + pos_inds = kpt_weights.sum(-1) > 0 + if pos_inds.sum() == 0: + pos_kpt_coords = torch.zeros_like(kpt_coords[:1]) + pos_img_inds = kpt_coords.new_zeros([1], dtype=torch.int64) + else: + pos_kpt_coords = kpt_coords[pos_inds] + pos_img_inds = (pos_inds.nonzero() / + self.num_queries).squeeze(1).to(torch.int64) + + kpt_decoder_inputs_dict.update( + dict( + pos_kpt_coords=pos_kpt_coords, + pos_img_inds=pos_img_inds, + )) + return kpt_decoder_inputs_dict + + def forward_kpt_decoder(self, memory, memory_mask, pos_kpt_coords, + pos_img_inds, spatial_shapes, level_start_index, + valid_ratios, **kwargs): + + kpt_query_embedding = self.kpt_query_embedding.weight + query_pos, query = torch.split( + kpt_query_embedding, kpt_query_embedding.size(1) // 2, dim=1) + pos_num = pos_kpt_coords.size(0) + query_pos = query_pos.unsqueeze(0).expand(pos_num, -1, -1) + query = query.unsqueeze(0).expand(pos_num, -1, -1) + reference_points = pos_kpt_coords.reshape(pos_num, + pos_kpt_coords.size(1) // 2, + 2).detach() + pos_memory = memory[pos_img_inds, :, :] + memory_mask = memory_mask[pos_img_inds, :] + valid_ratios = valid_ratios[pos_img_inds, ...] + + # forward_kpt_decoder + inter_states, inter_references = self.kpt_decoder( + query=query, + key=None, + value=pos_memory, + query_pos=query_pos, + key_padding_mask=memory_mask, + reference_points=reference_points, + spatial_shapes=spatial_shapes, + level_start_index=level_start_index, + valid_ratios=valid_ratios, + reg_branches=self.bbox_head.kpt_branches) + + kpt_decoder_outputs_dict = dict( + inter_states=inter_states, + reference_points=reference_points, + inter_references=inter_references, + ) + + return kpt_decoder_outputs_dict + + def forward_kpt_head(self, inter_states, reference_points, + inter_references): + outputs_kpts = [] + for lvl in range(inter_states.shape[0]): + if lvl == 0: + reference = reference_points + else: + reference = inter_references[lvl - 1] + reference = inverse_sigmoid(reference) + tmp_kpt = self.bbox_head.kpt_branches[lvl](inter_states[lvl]) + assert reference.shape[-1] == 2 + tmp_kpt += reference + outputs_kpt = tmp_kpt.sigmoid() + outputs_kpts.append(outputs_kpt) + + outputs_kpts = torch.stack(outputs_kpts) + return outputs_kpts + + def _load_state_dict_pre_hook(self, state_dict, prefix, local_meta, *args, + **kwargs): + """A hook function to convert the state dict from official repo to a + compatible format :class:`PETR`. + + The hook will be automatically registered during initialization. + """ + if 'mmengine_version' in local_meta: + return + + if local_meta.get('mmdet_verion', '0') > '3': + return + + mappings = OrderedDict() + mappings['bbox_head.transformer.'] = '' + mappings['level_embeds'] = 'level_embed' + mappings['bbox_head.query_embedding'] = 'query_embedding' + mappings['refine_query_embedding'] = 'kpt_query_embedding' + mappings['attentions.0'] = 'self_attn' + mappings['attentions.1'] = 'cross_attn' + mappings['ffns.0'] = 'ffn' + mappings['bbox_head.kpt_branches'] = 'bbox_head.reg_branches' + mappings['bbox_head.refine_kpt_branches'] = 'bbox_head.kpt_branches' + mappings['refine_decoder'] = 'kpt_decoder' + mappings['bbox_head.fc_hm'] = 'bbox_head.heatmap_fc' + mappings['enc_output_norm'] = 'memory_trans_norm' + mappings['enc_output'] = 'memory_trans_fc' + + # convert old-version state dict + for old_key, new_key in mappings.items(): + keys = list(state_dict.keys()) + for k in keys: + if old_key in k: + v = state_dict.pop(k) + k = k.replace(old_key, new_key) + state_dict[k] = v diff --git a/projects/petr/models/petr_head.py b/projects/petr/models/petr_head.py new file mode 100644 index 0000000000..53c297b1b8 --- /dev/null +++ b/projects/petr/models/petr_head.py @@ -0,0 +1,658 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy +from typing import List, Tuple + +import torch +import torch.nn as nn +from mmcv.cnn import Linear +from mmdet.models import inverse_sigmoid +from mmdet.models.dense_heads import DeformableDETRHead +from mmdet.models.utils import multi_apply +from mmdet.registry import MODELS +from mmdet.structures import SampleList +from mmdet.utils import InstanceList, reduce_mean +from mmengine.structures import InstanceData +from torch import Tensor + + +@MODELS.register_module() +class PETRHead(DeformableDETRHead): + + def __init__(self, + num_keypoints: int = 17, + num_pred_kpt_layer: int = 2, + loss_reg: dict = None, + loss_reg_aux: dict = None, + loss_oks: dict = None, + loss_oks_aux: dict = None, + loss_hm: dict = None, + *args, + **kwargs): + self.num_keypoints = num_keypoints + self.num_pred_kpt_layer = num_pred_kpt_layer + super().__init__(*args, **kwargs) + + self._target_buffer = dict() + + self.loss_reg = MODELS.build(loss_reg) + self.loss_reg_aux = MODELS.build(loss_reg_aux) + self.loss_oks = MODELS.build(loss_oks) + self.loss_oks_aux = MODELS.build(loss_oks_aux) + self.loss_hm = MODELS.build(loss_hm) + + def _init_layers(self) -> None: + """Initialize classification branch and regression branch of head.""" + fc_cls = Linear(self.embed_dims, self.cls_out_channels) + + reg_branch = [Linear(self.embed_dims, self.embed_dims * 2), nn.ReLU()] + for _ in range(self.num_reg_fcs): + reg_branch.append(Linear(self.embed_dims * 2, self.embed_dims * 2)) + reg_branch.append(nn.ReLU()) + reg_branch.append(Linear(self.embed_dims * 2, self.num_keypoints * 2)) + reg_branch = nn.Sequential(*reg_branch) + + kpt_branch = [] + for _ in range(self.num_reg_fcs): + kpt_branch.append(Linear(self.embed_dims, self.embed_dims)) + kpt_branch.append(nn.ReLU()) + kpt_branch.append(Linear(self.embed_dims, 2)) + kpt_branch = nn.Sequential(*kpt_branch) + + if self.share_pred_layer: + self.cls_branches = nn.ModuleList( + [fc_cls for _ in range(self.num_pred_layer)]) + self.reg_branches = nn.ModuleList( + [reg_branch for _ in range(self.num_pred_layer)]) + self.kpt_branches = nn.ModuleList( + [kpt_branch for _ in range(self.num_pred_kpt_layer)]) + else: + self.cls_branches = nn.ModuleList( + [copy.deepcopy(fc_cls) for _ in range(self.num_pred_layer)]) + self.reg_branches = nn.ModuleList([ + copy.deepcopy(reg_branch) for _ in range(self.num_pred_layer) + ]) + self.kpt_branches = nn.ModuleList([ + copy.deepcopy(kpt_branch) + for _ in range(self.num_pred_kpt_layer) + ]) + + self.heatmap_fc = Linear(self.embed_dims, self.num_keypoints) + + def forward(self, hidden_states: Tensor, + references: List[Tensor]) -> Tuple[Tensor]: + """Forward function. + + Args: + hidden_states (Tensor): Hidden states output from each decoder + layer, has shape (num_decoder_layers, bs, num_queries, dim). + references (list[Tensor]): List of the reference from the decoder. + The first reference is the `init_reference` (initial) and the + other num_decoder_layers(6) references are `inter_references` + (intermediate). The `init_reference` has shape (bs, + num_queries, 4) when `as_two_stage` of the detector is `True`, + otherwise (bs, num_queries, 2). Each `inter_reference` has + shape (bs, num_queries, 4) when `with_box_refine` of the + detector is `True`, otherwise (bs, num_queries, 2). The + coordinates are arranged as (cx, cy) when the last dimension is + 2, and (cx, cy, w, h) when it is 4. + + Returns: + tuple[Tensor]: results of head containing the following tensor. + + - all_layers_outputs_classes (Tensor): Outputs from the + classification head, has shape (num_decoder_layers, bs, + num_queries, cls_out_channels). + - all_layers_outputs_coords (Tensor): Sigmoid outputs from the + regression head with normalized coordinate format (cx, cy, w, + h), has shape (num_decoder_layers, bs, num_queries, 4) with the + last dimension arranged as (cx, cy, w, h). + """ + all_layers_outputs_classes = [] + all_layers_outputs_coords = [] + + for layer_id in range(hidden_states.shape[0]): + reference = inverse_sigmoid(references[layer_id]) + # NOTE The last reference will not be used. + hidden_state = hidden_states[layer_id] + outputs_class = self.cls_branches[layer_id](hidden_state) + tmp_reg_preds = self.reg_branches[layer_id](hidden_state) + if reference.shape[-1] == self.num_keypoints * 2: + # When `layer` is 0 and `as_two_stage` of the detector + # is `True`, or when `layer` is greater than 0 and + # `with_box_refine` of the detector is `True`. + tmp_reg_preds += reference + else: + # When `layer` is 0 and `as_two_stage` of the detector + # is `False`, or when `layer` is greater than 0 and + # `with_box_refine` of the detector is `False`. + assert reference.shape[-1] == 2 + tmp_reg_preds[..., :2] += reference + outputs_coord = tmp_reg_preds.sigmoid() + all_layers_outputs_classes.append(outputs_class) + all_layers_outputs_coords.append(outputs_coord) + + all_layers_outputs_classes = torch.stack(all_layers_outputs_classes) + all_layers_outputs_coords = torch.stack(all_layers_outputs_coords) + + return all_layers_outputs_classes, all_layers_outputs_coords + + def predict(self, + dec_outputs_coord: Tensor, + det_labels: Tensor, + det_scores: Tensor, + batch_data_samples: SampleList, + rescale: bool = True) -> InstanceList: + """Perform forward propagation and loss calculation of the detection + head on the queries of the upstream network. + + Args: + hidden_states (Tensor): Hidden states output from each decoder + layer, has shape (num_decoder_layers, num_queries, bs, dim). + references (list[Tensor]): List of the reference from the decoder. + The first reference is the `init_reference` (initial) and the + other num_decoder_layers(6) references are `inter_references` + (intermediate). The `init_reference` has shape (bs, + num_queries, 4) when `as_two_stage` of the detector is `True`, + otherwise (bs, num_queries, 2). Each `inter_reference` has + shape (bs, num_queries, 4) when `with_box_refine` of the + detector is `True`, otherwise (bs, num_queries, 2). The + coordinates are arranged as (cx, cy) when the last dimension is + 2, and (cx, cy, w, h) when it is 4. + batch_data_samples (list[:obj:`DetDataSample`]): The Data + Samples. It usually includes information such as + `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`. + rescale (bool, optional): If `True`, return boxes in original + image space. Defaults to `True`. + + Returns: + list[obj:`InstanceData`]: Detection results of each image + after the post process. + """ + batch_img_metas = [ + data_samples.metainfo for data_samples in batch_data_samples + ] + assert len(batch_img_metas) == 1, f'PETR only support test with ' \ + f'batch size 1, but got {len(batch_img_metas)}.' + img_meta = batch_img_metas[0] + + if dec_outputs_coord.ndim == 4: + dec_outputs_coord = dec_outputs_coord[-1] + + # filter instance + if self.test_cfg.get('score_thr', .0) > 0: + kept_inds = det_scores > self.test_cfg['score_thr'] + det_labels = det_labels[kept_inds] + det_scores = det_scores[kept_inds] + dec_outputs_coord = dec_outputs_coord[kept_inds] + + if len(dec_outputs_coord) > 0: + # decode keypoints + h, w = img_meta['img_shape'] + if rescale: + h = h / img_meta['scale_factor'][0] + w = w / img_meta['scale_factor'][1] + keypoints = torch.stack([ + dec_outputs_coord[..., 0] * w, + dec_outputs_coord[..., 1] * h, + ], + dim=2) + keypoint_scores = torch.ones(keypoints.shape[:-1]) + + # generate bounding boxes by outlining the detected poses + bboxes = torch.stack([ + keypoints[..., 0].min(dim=1).values.clamp(0, w), + keypoints[..., 1].min(dim=1).values.clamp(0, h), + keypoints[..., 0].max(dim=1).values.clamp(0, w), + keypoints[..., 1].max(dim=1).values.clamp(0, h), + ], + dim=1) + else: + keypoints = torch.empty(0, *dec_outputs_coord.shape[1:]) + keypoint_scores = torch.ones(keypoints.shape[:-1]) + bboxes = torch.empty(0, 4) + + results = InstanceData() + results.set_metainfo(img_meta) + results.bboxes = bboxes + results.scores = det_scores + results.bbox_scores = det_scores + results.labels = det_labels + results.keypoints = keypoints + results.keypoint_scores = keypoint_scores + results = results.numpy() + + return [results] + + def loss(self, enc_outputs_class: Tensor, enc_outputs_coord: Tensor, + all_layers_classes: Tensor, all_layers_coords: Tensor, + hm_memory: Tensor, hm_mask: Tensor, dec_outputs_coord: Tensor, + batch_data_samples: SampleList) -> dict: + """Perform forward propagation and loss calculation of the detection + head on the queries of the upstream network. + + Args: + hidden_states (Tensor): Hidden states output from each decoder + layer, has shape (num_decoder_layers, num_queries, bs, dim). + references (list[Tensor]): List of the reference from the decoder. + The first reference is the `init_reference` (initial) and the + other num_decoder_layers(6) references are `inter_references` + (intermediate). The `init_reference` has shape (bs, + num_queries, 4) when `as_two_stage` of the detector is `True`, + otherwise (bs, num_queries, 2). Each `inter_reference` has + shape (bs, num_queries, 4) when `with_box_refine` of the + detector is `True`, otherwise (bs, num_queries, 2). The + coordinates are arranged as (cx, cy) when the last dimension is + 2, and (cx, cy, w, h) when it is 4. + enc_outputs_class (Tensor): The score of each point on encode + feature map, has shape (bs, num_feat_points, cls_out_channels). + Only when `as_two_stage` is `True` it would be passed in, + otherwise it would be `None`. + enc_outputs_coord (Tensor): The proposal generate from the encode + feature map, has shape (bs, num_feat_points, 4) with the last + dimension arranged as (cx, cy, w, h). Only when `as_two_stage` + is `True` it would be passed in, otherwise it would be `None`. + batch_data_samples (list[:obj:`DetDataSample`]): The Data + Samples. It usually includes information such as + `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`. + + Returns: + dict: A dictionary of loss components. + """ + # enc_outputs_class: [bs, mlv_shape, 1] + # enc_outputs_coord: [bs, mlv_shape, 2*num_keypoints] + # all_layers_classes [3, bs, num_queries, 1] + # all_layers_coords [3, bs, num_queries, 2*num_keypoints] + # hm_memory: [bs, lv0_h, lv0_w, 256] + # hm_mask: [bs, lv0_h, lv0_w] + # dec_outputs_coord: [2, max_inst, num_keypoints, 2] + + batch_gt_instances = [] + batch_img_metas = [] + batch_gt_fields = [] + for data_sample in batch_data_samples: + batch_img_metas.append(data_sample.metainfo) + batch_gt_instances.append(data_sample.gt_instances) + batch_gt_fields.append(data_sample.gt_fields) + + loss_dict = dict() + + # calculate loss for decoder output + losses_cls, losses_kpt, losses_oks = multi_apply( + self.loss_by_feat_single, + all_layers_classes, + all_layers_coords, + batch_gt_instances=batch_gt_instances, + batch_img_metas=batch_img_metas, + cache_targets=True) + # loss from the last decoder layer + loss_dict['loss_cls'] = losses_cls[-1] + loss_dict['loss_kpt'] = losses_kpt[-1] + loss_dict['loss_oks'] = losses_oks[-1] + # loss from other decoder layers + num_dec_layer = 0 + for loss_cls_i, loss_kpt_i, loss_oks_i in zip(losses_cls[:-1], + losses_kpt[:-1], + losses_oks[:-1]): + loss_dict[f'd{num_dec_layer}.loss_cls'] = loss_cls_i + loss_dict[f'd{num_dec_layer}.loss_kpt'] = loss_kpt_i + loss_dict[f'd{num_dec_layer}.loss_oks'] = loss_oks_i + num_dec_layer += 1 + + # calculate loss for encoder output + losses_cls, losses_kpt, losses_oks = self.loss_by_feat_single( + enc_outputs_class, + enc_outputs_coord, + batch_gt_instances=batch_gt_instances, + batch_img_metas=batch_img_metas, + cache_targets=False, + compute_oks_loss=False) + loss_dict['enc_loss_cls'] = losses_cls + loss_dict['enc_loss_kpt'] = losses_kpt + + # calculate heatmap loss + loss_hm = self.loss_heatmap( + hm_memory, hm_mask, batch_gt_fields=batch_gt_fields) + loss_dict['loss_hm'] = loss_hm + + # calculate loss for kpt_decoder output + losses_kpt, losses_oks = multi_apply( + self.loss_refined_kpts, + dec_outputs_coord, + batch_img_metas=batch_img_metas, + ) + + num_dec_layer = 0 + for loss_kpt_i, loss_oks_i in zip(losses_kpt, losses_oks): + loss_dict[f'd{num_dec_layer}.loss_kpt_refine'] = loss_kpt_i + loss_dict[f'd{num_dec_layer}.loss_oks_refine'] = loss_oks_i + num_dec_layer += 1 + + self._target_buffer.clear() + return loss_dict + + # TODO: rename this method + def loss_by_feat_single(self, + cls_scores: Tensor, + kpt_preds: Tensor, + batch_gt_instances: InstanceList, + batch_img_metas: List[dict], + cache_targets: bool = False, + compute_oks_loss: bool = True) -> Tuple[Tensor]: + """Loss function for outputs from a single decoder layer of a single + feature level. + + Args: + cls_scores (Tensor): Box score logits from a single decoder layer + for all images, has shape (bs, num_queries, cls_out_channels). + kpt_preds (Tensor): Sigmoid outputs from a single decoder layer + for all images, with normalized coordinate (cx, cy, w, h) and + shape (bs, num_queries, 4). + batch_gt_instances (list[:obj:`InstanceData`]): Batch of + gt_instance. It usually includes ``bboxes`` and ``labels`` + attributes. + batch_img_metas (list[dict]): Meta information of each image, e.g., + image size, scaling factor, etc. + + Returns: + Tuple[Tensor]: A tuple including `loss_cls`, `loss_box` and + `loss_iou`. + """ + # cls_scores [bs, num_queries, 1] + # kpt_preds [bs, num_queries, 2*num_keypoitns] + + num_imgs, num_queries = cls_scores.shape[:2] + + cls_scores_list = [cls_scores[i] for i in range(num_imgs)] + kpt_preds_list = [ + kpt_preds[i].reshape(-1, self.num_keypoints, 2) + for i in range(num_imgs) + ] + cls_reg_targets = self.get_targets(cls_scores_list, kpt_preds_list, + batch_gt_instances, batch_img_metas) + (labels_list, label_weights_list, bbox_targets_list, kpt_targets_list, + kpt_weights_list, num_total_pos, num_total_neg) = cls_reg_targets + labels = torch.cat(labels_list, 0) # [bs*300] + label_weights = torch.cat(label_weights_list, 0) # [bs*300] (all 1) + bbox_targets = torch.cat(bbox_targets_list, + 0) # [bs*300, 4] (normalized) + kpt_targets = torch.cat(kpt_targets_list, + 0) # [bs*300, 17, 2] (normalized) + kpt_weights = torch.cat(kpt_weights_list, 0) # [bs*300, 17] + + # keypoint regression loss + kpt_preds = kpt_preds.reshape(-1, self.num_keypoints, 2) + num_valid_kpt = torch.clamp( + reduce_mean(kpt_weights.sum()), min=1).item() + # assert num_valid_kpt == (kpt_targets>0).sum().item() + loss_kpt = self.loss_reg_aux( + kpt_preds, + kpt_targets, + kpt_weights.unsqueeze(-1), + avg_factor=num_valid_kpt) + + # classification loss + cls_scores = cls_scores.reshape(-1, self.cls_out_channels) + # construct weighted avg_factor to match with the official DETR repo + cls_avg_factor = num_total_pos * 1.0 + \ + num_total_neg * self.bg_cls_weight + if self.sync_cls_avg_factor: + cls_avg_factor = reduce_mean( + cls_scores.new_tensor([cls_avg_factor])) + cls_avg_factor = max(cls_avg_factor, 1) + + loss_cls = self.loss_cls( + cls_scores, labels, label_weights, avg_factor=cls_avg_factor) + + pos_mask = kpt_weights.sum(-1) > 0 + if compute_oks_loss and pos_mask.any().item(): + pos_inds = (pos_mask.nonzero()).div( + num_queries, rounding_mode='trunc').squeeze(-1) + + # construct factors used for rescale keypoints + factors = [] + for img_meta, kpt_pred in zip(batch_img_metas, kpt_preds): + img_h, img_w, = img_meta['img_shape'] + factor = kpt_pred.new_tensor([img_w, img_h]).reshape(1, 1, 2) + factors.append(factor) + factors = torch.cat(factors, 0) + factors = factors[pos_inds] + + # keypoint oks loss + pos_kpt_preds = kpt_preds[pos_mask] * factors + pos_kpt_targets = kpt_targets[pos_mask] * factors + pos_kpt_weights = kpt_weights[pos_mask] + pos_bbox_targets = (bbox_targets[pos_mask].reshape(-1, 2, 2) * + factor).reshape(-1, 4) + + loss_oks = self.loss_oks_aux(pos_kpt_preds, pos_kpt_targets, + pos_kpt_weights, pos_bbox_targets) + else: + loss_oks = torch.zeros_like(loss_kpt) + + return loss_cls, loss_kpt, loss_oks + + def loss_heatmap(self, hm_memory, hm_mask, batch_gt_fields): + + # compute heatmap predition + pred_heatmaps = self.heatmap_fc(hm_memory) + pred_heatmaps = torch.clamp( + pred_heatmaps.sigmoid_(), min=1e-4, max=1 - 1e-4) + pred_heatmaps = pred_heatmaps.permute(0, 3, 1, 2).contiguous() + + # construct heatmap target + gt_heatmaps = torch.zeros_like(pred_heatmaps) + for i, gf in enumerate(batch_gt_fields): + gt_heatmap = gf.gt_heatmaps + h = min(gt_heatmap.size(1), gt_heatmaps.size(2)) + w = min(gt_heatmap.size(2), gt_heatmaps.size(3)) + gt_heatmaps[i, :, :h, :w] = gt_heatmap[:, :h, :w] + + loss_hm = self.loss_hm(pred_heatmaps, gt_heatmaps, None, + 1 - hm_mask.unsqueeze(1).float()) + return loss_hm + + def loss_refined_kpts(self, kpt_preds: Tensor, + batch_img_metas: List[dict]) -> Tuple[Tensor]: + """Loss function for outputs from a single decoder layer of a single + feature level. + + Args: + cls_scores (Tensor): Box score logits from a single decoder layer + for all images, has shape (bs, num_queries, cls_out_channels). + kpt_preds (Tensor): Sigmoid outputs from a single decoder layer + for all images, with normalized coordinate (cx, cy, w, h) and + shape (bs, num_queries, 4). + batch_gt_instances (list[:obj:`InstanceData`]): Batch of + gt_instance. It usually includes ``bboxes`` and ``labels`` + attributes. + batch_img_metas (list[dict]): Meta information of each image, e.g., + image size, scaling factor, etc. + + Returns: + Tuple[Tensor]: A tuple including `loss_cls`, `loss_box` and + `loss_iou`. + """ + # kpt_preds [num_selected, num_keypoints, 2] + bbox_targets_list = self._target_buffer['bbox_targets_list'] + kpt_targets_list = self._target_buffer['kpt_targets_list'] + kpt_weights_list = self._target_buffer['kpt_weights_list'] + num_queries = len(kpt_targets_list[0]) + bbox_targets = torch.cat(bbox_targets_list, + 0).contiguous() # [bs*300, 4] (normalized) + kpt_targets = torch.cat(kpt_targets_list, + 0).contiguous() # [bs*300, 17, 2] (normalized) + kpt_weights = torch.cat(kpt_weights_list, + 0).contiguous() # [bs*300, 17] + + pos_mask = (kpt_weights.sum(-1) > 0).contiguous() + pos_inds = (pos_mask.nonzero()).div( + num_queries, rounding_mode='trunc').squeeze(-1) + + # keypoint regression loss + kpt_preds = kpt_preds.reshape(-1, self.num_keypoints, 2) + num_valid_kpt = torch.clamp( + reduce_mean(kpt_weights.sum()), min=1).item() + # assert num_valid_kpt == (kpt_targets>0).sum().item() + loss_kpt = self.loss_reg( + kpt_preds, + kpt_targets[pos_mask], + kpt_weights[pos_mask].unsqueeze(-1), + avg_factor=num_valid_kpt) + + if pos_mask.any().item(): + # construct factors used for rescale keypoints + factors = [] + for img_meta in batch_img_metas: + img_h, img_w, = img_meta['img_shape'] + factor = kpt_preds.new_tensor([img_w, img_h]).reshape(1, 1, 2) + factors.append(factor) + factors = torch.cat(factors, 0) + + factors = factors[pos_inds] + + # keypoint oks loss + pos_kpt_preds = kpt_preds * factors + pos_kpt_targets = kpt_targets[pos_mask] + + pos_kpt_targets = pos_kpt_targets * factors + pos_kpt_weights = kpt_weights[pos_mask] + pos_bbox_targets = (bbox_targets[pos_mask].reshape(-1, 2, 2) * + factor).reshape(-1, 4) + + loss_oks = self.loss_oks_aux(pos_kpt_preds, pos_kpt_targets, + pos_kpt_weights, pos_bbox_targets) + else: + loss_oks = torch.zeros_like(loss_kpt) + + return loss_kpt, loss_oks + + @torch.no_grad() + def get_targets(self, + cls_scores_list: List[Tensor], + kpt_preds_list: List[Tensor], + batch_gt_instances: InstanceList, + batch_img_metas: List[dict], + cache_targets: bool = False) -> tuple: + """Compute regression and classification targets for a batch image. + + Outputs from a single decoder layer of a single feature level are used. + + Args: + cls_scores_list (list[Tensor]): Box score logits from a single + decoder layer for each image, has shape [num_queries, + cls_out_channels]. + kpt_preds_list (list[Tensor]): Sigmoid outputs from a single + decoder layer for each image, with normalized coordinate + (cx, cy, w, h) and shape [num_queries, 4]. + batch_gt_instances (list[:obj:`InstanceData`]): Batch of + gt_instance. It usually includes ``bboxes`` and ``labels`` + attributes. + batch_img_metas (list[dict]): Meta information of each image, e.g., + image size, scaling factor, etc. + + Returns: + tuple: a tuple containing the following targets. + + - labels_list (list[Tensor]): Labels for all images. + - label_weights_list (list[Tensor]): Label weights for all images. + - bbox_targets_list (list[Tensor]): BBox targets for all images. + - bbox_weights_list (list[Tensor]): BBox weights for all images. + - num_total_pos (int): Number of positive samples in all images. + - num_total_neg (int): Number of negative samples in all images. + """ + (labels_list, label_weights_list, bbox_targets_list, kpt_targets_list, + kpt_weights_list, pos_inds_list, + neg_inds_list) = multi_apply(self._get_targets_single, + cls_scores_list, kpt_preds_list, + batch_gt_instances, batch_img_metas) + num_total_pos = sum((inds.numel() for inds in pos_inds_list)) + num_total_neg = sum((inds.numel() for inds in neg_inds_list)) + + if cache_targets: + self._target_buffer['labels_list'] = labels_list + self._target_buffer['label_weights_list'] = label_weights_list + self._target_buffer['bbox_targets_list'] = bbox_targets_list + self._target_buffer['kpt_targets_list'] = kpt_targets_list + self._target_buffer['kpt_weights_list'] = kpt_weights_list + self._target_buffer['num_total_pos'] = num_total_pos + self._target_buffer['num_total_neg'] = num_total_neg + + return (labels_list, label_weights_list, bbox_targets_list, + kpt_targets_list, kpt_weights_list, num_total_pos, + num_total_neg) + + def _get_targets_single(self, cls_score: Tensor, kpt_pred: Tensor, + gt_instances: InstanceData, + img_meta: dict) -> tuple: + """Compute regression and classification targets for one image. + + Outputs from a single decoder layer of a single feature level are used. + + Args: + cls_score (Tensor): Box score logits from a single decoder layer + for one image. Shape [num_queries, cls_out_channels]. + kpt_pred (Tensor): Sigmoid outputs from a single decoder layer + for one image, with normalized coordinate (cx, cy, w, h) and + shape [num_queries, 4]. + gt_instances (:obj:`InstanceData`): Ground truth of instance + annotations. It should includes ``bboxes`` and ``labels`` + attributes. + img_meta (dict): Meta information for one image. + + Returns: + tuple[Tensor]: a tuple containing the following for one image. + + - labels (Tensor): Labels of each image. + - label_weights (Tensor]): Label weights of each image. + - bbox_targets (Tensor): BBox targets of each image. + - bbox_weights (Tensor): BBox weights of each image. + - pos_inds (Tensor): Sampled positive indices for each image. + - neg_inds (Tensor): Sampled negative indices for each image. + """ + img_h, img_w = img_meta['img_shape'] + num_insts = kpt_pred.size(0) + factor = kpt_pred.new_tensor([img_w, img_h]).unsqueeze(0).unsqueeze(1) + kpt_pred = kpt_pred * factor + + pred_instances = InstanceData(scores=cls_score, keypoints=kpt_pred) + # assigner and sampler + assign_result = self.assigner.assign( + pred_instances=pred_instances, + gt_instances=gt_instances, + img_meta=img_meta) + + gt_keypoints = gt_instances.keypoints + gt_keypoints_visible = gt_instances.keypoints_visible + gt_labels = gt_instances.labels + gt_bboxes = gt_instances.bboxes + pos_inds = torch.nonzero( + assign_result.gt_inds > 0, as_tuple=False).squeeze(-1).unique() + neg_inds = torch.nonzero( + assign_result.gt_inds == 0, as_tuple=False).squeeze(-1).unique() + pos_assigned_gt_inds = assign_result.gt_inds[pos_inds] - 1 + + # label targets + labels = gt_labels.new_full((num_insts, ), + self.num_classes, + dtype=torch.long) + labels[pos_inds] = gt_labels[pos_assigned_gt_inds] + label_weights = gt_labels.new_ones(num_insts) + + # kpt targets + kpt_targets = torch.zeros_like(kpt_pred) + pos_gt_keypoints = gt_keypoints[pos_assigned_gt_inds.long(), :] + kpt_targets[pos_inds] = pos_gt_keypoints / factor + kpt_weights = torch.zeros_like(kpt_pred).narrow(-1, 0, 1).squeeze(-1) + pos_gt_keypoints_visible = gt_keypoints_visible[ + pos_assigned_gt_inds.long()] + kpt_weights[pos_inds] = (pos_gt_keypoints_visible > 0).float() + + # bbox_targets, which is used to compute oks loss + bbox_targets = torch.zeros_like(kpt_pred).narrow(-2, 0, 2) + pos_gt_bboxes = gt_bboxes[pos_assigned_gt_inds.long()] + bbox_targets[pos_inds] = pos_gt_bboxes.reshape( + *pos_gt_bboxes.shape[:-1], 2, 2) / factor + bbox_targets = bbox_targets.flatten(-2) + + return (labels, label_weights, bbox_targets, kpt_targets, kpt_weights, + pos_inds, neg_inds) diff --git a/projects/petr/models/transformers.py b/projects/petr/models/transformers.py new file mode 100644 index 0000000000..8f02570a1f --- /dev/null +++ b/projects/petr/models/transformers.py @@ -0,0 +1,254 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Optional, Tuple, no_type_check + +import torch +from mmcv.cnn import build_norm_layer +from mmcv.cnn.bricks.transformer import FFN, MultiheadAttention +from mmcv.ops import MultiScaleDeformableAttention +from mmcv.ops.multi_scale_deform_attn import ( + MultiScaleDeformableAttnFunction, multi_scale_deformable_attn_pytorch) +from mmdet.models.layers import (DeformableDetrTransformerDecoder, + DeformableDetrTransformerDecoderLayer) +from mmdet.models.layers.transformer.utils import inverse_sigmoid +from mmengine.model import ModuleList +from mmengine.utils import deprecated_api_warning +from torch import Tensor, nn + + +class MultiScaleDeformablePoseAttention(MultiScaleDeformableAttention): + + @no_type_check + @deprecated_api_warning({'residual': 'identity'}, + cls_name='MultiScaleDeformablePoseAttention') + def forward(self, + query: torch.Tensor, + key: Optional[torch.Tensor] = None, + value: Optional[torch.Tensor] = None, + identity: Optional[torch.Tensor] = None, + query_pos: Optional[torch.Tensor] = None, + key_padding_mask: Optional[torch.Tensor] = None, + reference_points: Optional[torch.Tensor] = None, + spatial_shapes: Optional[torch.Tensor] = None, + level_start_index: Optional[torch.Tensor] = None, + **kwargs) -> torch.Tensor: + """Forward Function of MultiScaleDeformAttention. + + Args: + query (torch.Tensor): Query of Transformer with shape + (num_query, bs, embed_dims). + key (torch.Tensor): The key tensor with shape + `(num_key, bs, embed_dims)`. + value (torch.Tensor): The value tensor with shape + `(num_key, bs, embed_dims)`. + identity (torch.Tensor): The tensor used for addition, with the + same shape as `query`. Default None. If None, + `query` will be used. + query_pos (torch.Tensor): The positional encoding for `query`. + Default: None. + key_padding_mask (torch.Tensor): ByteTensor for `query`, with + shape [bs, num_key]. + reference_points (torch.Tensor): The normalized reference + points with shape (bs, num_query, num_levels, 2), + all elements is range in [0, 1], top-left (0,0), + bottom-right (1, 1), including padding area. + or (N, Length_{query}, num_levels, 4), add + additional two dimensions is (w, h) to + form reference boxes. + spatial_shapes (torch.Tensor): Spatial shape of features in + different levels. With shape (num_levels, 2), + last dimension represents (h, w). + level_start_index (torch.Tensor): The start index of each level. + A tensor has shape ``(num_levels, )`` and can be represented + as [0, h_0*w_0, h_0*w_0+h_1*w_1, ...]. + + Returns: + torch.Tensor: forwarded results with shape + [num_query, bs, embed_dims]. + """ + + if value is None: + value = query + + if identity is None: + identity = query + if query_pos is not None: + query = query + query_pos + if not self.batch_first: + # change to (bs, num_query ,embed_dims) + query = query.permute(1, 0, 2) + value = value.permute(1, 0, 2) + + bs, num_query, _ = query.shape + bs, num_value, _ = value.shape + assert (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() == num_value + + value = self.value_proj(value) + if key_padding_mask is not None: + value = value.masked_fill(key_padding_mask[..., None], 0.0) + value = value.view(bs, num_value, self.num_heads, -1) + sampling_offsets = self.sampling_offsets(query).view( + bs, num_query, self.num_heads, self.num_levels, self.num_points, 2) + attention_weights = self.attention_weights(query).view( + bs, num_query, self.num_heads, self.num_levels * self.num_points) + attention_weights = attention_weights.softmax(-1) + + attention_weights = attention_weights.view(bs, num_query, + self.num_heads, + self.num_levels, + self.num_points) + if reference_points.shape[-1] == self.num_points * 2: + reference_points_reshape = reference_points.reshape( + bs, num_query, self.num_levels, -1, 2).unsqueeze(2) + x1 = reference_points[:, :, :, 0::2].min(dim=-1, keepdim=True)[0] + y1 = reference_points[:, :, :, 1::2].min(dim=-1, keepdim=True)[0] + x2 = reference_points[:, :, :, 0::2].max(dim=-1, keepdim=True)[0] + y2 = reference_points[:, :, :, 1::2].max(dim=-1, keepdim=True)[0] + w = torch.clamp(x2 - x1, min=1e-4) + h = torch.clamp(y2 - y1, min=1e-4) + wh = torch.cat([w, h], dim=-1)[:, :, None, :, None, :] + sampling_locations = reference_points_reshape \ + + sampling_offsets * wh * 0.5 + + else: + raise ValueError( + f'Last dim of reference_points must be {self.num_points*2}, ' + f'but get {reference_points.shape[-1]} instead.') + if torch.cuda.is_available() and value.is_cuda: + output = MultiScaleDeformableAttnFunction.apply( + value, spatial_shapes, level_start_index, sampling_locations, + attention_weights, self.im2col_step) + else: + output = multi_scale_deformable_attn_pytorch( + value, spatial_shapes, sampling_locations, attention_weights) + + output = self.output_proj(output) + + if not self.batch_first: + # (num_query, bs ,embed_dims) + output = output.permute(1, 0, 2) + + return self.dropout(output) + identity + + +class PetrTransformerDecoderLayer(DeformableDetrTransformerDecoderLayer): + + def _init_layers(self) -> None: + """Initialize self_attn, cross-attn, ffn, and norms.""" + self.self_attn = MultiheadAttention(**self.self_attn_cfg) + self.cross_attn = MultiScaleDeformablePoseAttention( + **self.cross_attn_cfg) + self.embed_dims = self.self_attn.embed_dims + self.ffn = FFN(**self.ffn_cfg) + norms_list = [ + build_norm_layer(self.norm_cfg, self.embed_dims)[1] + for _ in range(3) + ] + self.norms = ModuleList(norms_list) + + +class PetrTransformerDecoder(DeformableDetrTransformerDecoder): + + def __init__(self, num_keypoints: int, *args, **kwargs): + self.num_keypoints = num_keypoints + super().__init__(*args, **kwargs) + + def _init_layers(self) -> None: + """Initialize decoder layers.""" + self.layers = ModuleList([ + PetrTransformerDecoderLayer(**self.layer_cfg) + for _ in range(self.num_layers) + ]) + self.embed_dims = self.layers[0].embed_dims + if self.post_norm_cfg is not None: + raise ValueError('There is not post_norm in ' + f'{self._get_name()}') + + def forward(self, + query: Tensor, + query_pos: Tensor, + value: Tensor, + key_padding_mask: Tensor, + reference_points: Tensor, + spatial_shapes: Tensor, + level_start_index: Tensor, + valid_ratios: Tensor, + reg_branches: Optional[nn.Module] = None, + **kwargs) -> Tuple[Tensor]: + """Forward function of Transformer decoder. + + Args: + query (Tensor): The input queries, has shape (bs, num_queries, + dim). + query_pos (Tensor): The input positional query, has shape + (bs, num_queries, dim). It will be added to `query` before + forward function. + value (Tensor): The input values, has shape (bs, num_value, dim). + key_padding_mask (Tensor): The `key_padding_mask` of `cross_attn` + input. ByteTensor, has shape (bs, num_value). + reference_points (Tensor): The initial reference, has shape + (bs, num_queries, 4) with the last dimension arranged as + (cx, cy, w, h) when `as_two_stage` is `True`, otherwise has + shape (bs, num_queries, 2) with the last dimension arranged + as (cx, cy). + spatial_shapes (Tensor): Spatial shapes of features in all levels, + has shape (num_levels, 2), last dimension represents (h, w). + level_start_index (Tensor): The start index of each level. + A tensor has shape (num_levels, ) and can be represented + as [0, h_0*w_0, h_0*w_0+h_1*w_1, ...]. + valid_ratios (Tensor): The ratios of the valid width and the valid + height relative to the width and the height of features in all + levels, has shape (bs, num_levels, 2). + reg_branches: (obj:`nn.ModuleList`, optional): Used for refining + the regression results. Only would be passed when + `with_box_refine` is `True`, otherwise would be `None`. + + Returns: + tuple[Tensor]: Outputs of Deformable Transformer Decoder. + + - output (Tensor): Output embeddings of the last decoder, has + shape (num_queries, bs, embed_dims) when `return_intermediate` + is `False`. Otherwise, Intermediate output embeddings of all + decoder layers, has shape (num_decoder_layers, num_queries, bs, + embed_dims). + - reference_points (Tensor): The reference of the last decoder + layer, has shape (bs, num_queries, 4) when `return_intermediate` + is `False`. Otherwise, Intermediate references of all decoder + layers, has shape (num_decoder_layers, bs, num_queries, 4). The + coordinates are arranged as (cx, cy, w, h) + """ + output = query + intermediate = [] + intermediate_reference_points = [] + for layer_id, layer in enumerate(self.layers): + assert reference_points.shape[-1] == self.num_keypoints * 2 + reference_points_input = \ + reference_points[:, :, None] * \ + valid_ratios.repeat(1, 1, self.num_keypoints)[:, None] + + output = layer( + output, + query_pos=query_pos, + value=value, + key_padding_mask=key_padding_mask, + spatial_shapes=spatial_shapes, + level_start_index=level_start_index, + valid_ratios=valid_ratios, + reference_points=reference_points_input, + **kwargs) + + if reg_branches is not None: + tmp_reg_preds = reg_branches[layer_id](output) + new_reference_points = tmp_reg_preds + inverse_sigmoid( + reference_points) + new_reference_points = new_reference_points.sigmoid() + reference_points = new_reference_points.detach() + + if self.return_intermediate: + intermediate.append(output) + intermediate_reference_points.append(reference_points) + + if self.return_intermediate: + return torch.stack(intermediate), torch.stack( + intermediate_reference_points) + + return output, reference_points From 3a32890423840d68e0fcfc895f2f402a3523e259 Mon Sep 17 00:00:00 2001 From: lupeng Date: Tue, 18 Apr 2023 20:54:43 +0800 Subject: [PATCH 02/12] update --- .../petr/configs/petr_r50_4xb8-100e_coco.py | 26 +++++++++++++++++++ .../configs/petr_r50_4xb8-100e_coco_param1.py | 26 +++++++++++++++++++ .../petr/configs/petr_r50_8xb4-100e_coco.py | 7 ++--- projects/petr/datasets/transforms.py | 7 +++++ projects/petr/models/match_costs.py | 15 +++++++---- projects/petr/models/petr.py | 3 ++- 6 files changed, 75 insertions(+), 9 deletions(-) create mode 100644 projects/petr/configs/petr_r50_4xb8-100e_coco.py create mode 100644 projects/petr/configs/petr_r50_4xb8-100e_coco_param1.py diff --git a/projects/petr/configs/petr_r50_4xb8-100e_coco.py b/projects/petr/configs/petr_r50_4xb8-100e_coco.py new file mode 100644 index 0000000000..5c5c22d4bb --- /dev/null +++ b/projects/petr/configs/petr_r50_4xb8-100e_coco.py @@ -0,0 +1,26 @@ +_base_ = ['./petr_r50_8xb4-100e_coco.py'] + +model = dict( + bbox_head=dict( + loss_cls=dict(loss_weight=2.0), + loss_reg=dict(type='L1Loss', loss_weight=80.0), + loss_reg_aux=dict(type='L1Loss', loss_weight=70.0), + loss_oks=dict(type='OksLoss', loss_weight=30.0), + loss_oks_aux=dict(type='OksLoss', loss_weight=20.0), + loss_hm=dict(type='mmpose.FocalHeatmapLoss', loss_weight=4.0), + ), + # training and testing settings + train_cfg=dict( + assigner=dict( + type='HungarianAssigner', + match_costs=[ + dict(type='FocalLossCost', weight=2.0), + dict(type='KptL1Cost', weight=70.0), + dict( + type='OksCost', + metainfo='configs/_base_/datasets/coco.py', + weight=70.0) + ]))) + + +train_dataloader = dict(batch_size=8) \ No newline at end of file diff --git a/projects/petr/configs/petr_r50_4xb8-100e_coco_param1.py b/projects/petr/configs/petr_r50_4xb8-100e_coco_param1.py new file mode 100644 index 0000000000..c496edde5e --- /dev/null +++ b/projects/petr/configs/petr_r50_4xb8-100e_coco_param1.py @@ -0,0 +1,26 @@ +_base_ = ['./petr_r50_8xb4-100e_coco.py'] + +model = dict( + bbox_head=dict( + loss_cls=dict(loss_weight=2.0), + loss_reg=dict(type='L1Loss', loss_weight=8.0), + loss_reg_aux=dict(type='L1Loss', loss_weight=7.0), + loss_oks=dict(type='OksLoss', loss_weight=3.0), + loss_oks_aux=dict(type='OksLoss', loss_weight=2.0), + loss_hm=dict(type='mmpose.FocalHeatmapLoss', loss_weight=2.0), + ), + # training and testing settings + train_cfg=dict( + assigner=dict( + type='HungarianAssigner', + match_costs=[ + dict(type='FocalLossCost', weight=2.0), + dict(type='KptL1Cost', weight=7.0), + dict( + type='OksCost', + metainfo='configs/_base_/datasets/coco.py', + weight=7.0) + ]))) + + +train_dataloader = dict(batch_size=8) \ No newline at end of file diff --git a/projects/petr/configs/petr_r50_8xb4-100e_coco.py b/projects/petr/configs/petr_r50_8xb4-100e_coco.py index 36a1434344..ad52daf802 100644 --- a/projects/petr/configs/petr_r50_8xb4-100e_coco.py +++ b/projects/petr/configs/petr_r50_8xb4-100e_coco.py @@ -134,8 +134,8 @@ gamma=2.0, alpha=0.25, loss_weight=2.0), - loss_reg=dict(type='L1Loss', loss_weight=80.0), - loss_reg_aux=dict(type='L1Loss', loss_weight=70.0), + loss_reg=dict(type='L1Loss', loss_weight=40.0), + loss_reg_aux=dict(type='L1Loss', loss_weight=35.0), loss_oks=dict( type='OksLoss', metainfo='configs/_base_/datasets/coco.py', @@ -174,6 +174,7 @@ # scaling_ratio_range=(1., 1.), # max_shear_degree=0., scaling_ratio_range=(0.75, 1.0), + border_val=[103.53, 116.28, 123.675], ), dict(type='RandomFlip', prob=0.5), dict( @@ -224,7 +225,7 @@ data_root = 'data/coco/' train_dataloader = dict( - batch_size=2, + batch_size=4, num_workers=2, persistent_workers=True, sampler=dict(type='DefaultSampler', shuffle=True), diff --git a/projects/petr/datasets/transforms.py b/projects/petr/datasets/transforms.py index 82fd5b0f8c..5262de83ff 100644 --- a/projects/petr/datasets/transforms.py +++ b/projects/petr/datasets/transforms.py @@ -85,6 +85,13 @@ def transform(self, results: dict) -> dict: gt_fields.set_field(results[key], packed_key) + # Ensure all keys in `self.meta_keys` are in the `results` dictionary, + # which is necessary for `PackDetInputs` but not guaranteed during + # inference with an inferencer + for key in self.meta_keys: + if key not in results: + results[key] = None + results = super().transform(results) if gt_fields: results['data_samples'].gt_fields = gt_fields.to_tensor() diff --git a/projects/petr/models/match_costs.py b/projects/petr/models/match_costs.py index 0930ef9a83..18b350d0b1 100644 --- a/projects/petr/models/match_costs.py +++ b/projects/petr/models/match_costs.py @@ -21,14 +21,19 @@ def __call__(self, pred_keypoints = pred_instances.keypoints gt_keypoints = gt_instances.keypoints - + gt_keypoints_visible = gt_instances.keypoints_visible + gt_keypoints_visible = gt_keypoints_visible / (2 * gt_keypoints_visible.sum(dim=1, keepdim=True) + 1e-8) + # normalized img_h, img_w = img_meta['img_shape'] factor = gt_keypoints.new_tensor([img_w, img_h]).reshape(1, 1, 2) - gt_keypoints = (gt_keypoints / factor).flatten(1) - pred_keypoints = (pred_keypoints / factor).flatten(1) - - kpt_cost = torch.cdist(pred_keypoints, gt_keypoints, p=1) + gt_keypoints = (gt_keypoints / factor).unsqueeze(0) + gt_keypoints_visible = gt_keypoints_visible.unsqueeze(0).unsqueeze(-1) + pred_keypoints = (pred_keypoints / factor).unsqueeze(1) + + diff = (pred_keypoints - gt_keypoints) * gt_keypoints_visible + kpt_cost = diff.flatten(2).norm(dim=2, p=1) + return kpt_cost * self.weight diff --git a/projects/petr/models/petr.py b/projects/petr/models/petr.py index 11de5a0d57..c94794fb6d 100644 --- a/projects/petr/models/petr.py +++ b/projects/petr/models/petr.py @@ -383,7 +383,6 @@ def forward_encoder(self, hm_memory = self.hm_encoder( query=hm_memory, query_pos=hm_pos_embed, - query_pos=None, key_padding_mask=hm_mask, spatial_shapes=spatial_shapes.narrow(0, 0, 1), level_start_index=level_start_index[0], @@ -685,6 +684,8 @@ def _load_state_dict_pre_hook(self, state_dict, prefix, local_meta, *args, The hook will be automatically registered during initialization. """ + return + if 'mmengine_version' in local_meta: return From 0ff2e392771a3f677faa80807a17778d07077835 Mon Sep 17 00:00:00 2001 From: lupeng Date: Thu, 20 Apr 2023 02:19:52 +0800 Subject: [PATCH 03/12] update --- projects/petr/datasets/transforms.py | 2 +- projects/petr/models/petr_head.py | 52 +++++++++++++--------------- 2 files changed, 25 insertions(+), 29 deletions(-) diff --git a/projects/petr/datasets/transforms.py b/projects/petr/datasets/transforms.py index 5262de83ff..17a3e7709e 100644 --- a/projects/petr/datasets/transforms.py +++ b/projects/petr/datasets/transforms.py @@ -143,7 +143,7 @@ def _get_instance_wise_sigmas(self, sq3 = np.sqrt(b3**2 - 4 * a3 * c3) r3 = (b3 + sq3) / 2 - sigmas[i] = min(r1, r2, r3) / 3 + sigmas[i] = min(r1, r2, r3, 3) / 3 return sigmas diff --git a/projects/petr/models/petr_head.py b/projects/petr/models/petr_head.py index 53c297b1b8..d45bdace1a 100644 --- a/projects/petr/models/petr_head.py +++ b/projects/petr/models/petr_head.py @@ -282,8 +282,7 @@ def loss(self, enc_outputs_class: Tensor, enc_outputs_coord: Tensor, all_layers_classes, all_layers_coords, batch_gt_instances=batch_gt_instances, - batch_img_metas=batch_img_metas, - cache_targets=True) + batch_img_metas=batch_img_metas) # loss from the last decoder layer loss_dict['loss_cls'] = losses_cls[-1] loss_dict['loss_kpt'] = losses_kpt[-1] @@ -299,12 +298,11 @@ def loss(self, enc_outputs_class: Tensor, enc_outputs_coord: Tensor, num_dec_layer += 1 # calculate loss for encoder output - losses_cls, losses_kpt, losses_oks = self.loss_by_feat_single( + losses_cls, losses_kpt, _ = self.loss_by_feat_single( enc_outputs_class, enc_outputs_coord, batch_gt_instances=batch_gt_instances, batch_img_metas=batch_img_metas, - cache_targets=False, compute_oks_loss=False) loss_dict['enc_loss_cls'] = losses_cls loss_dict['enc_loss_kpt'] = losses_kpt @@ -336,7 +334,6 @@ def loss_by_feat_single(self, kpt_preds: Tensor, batch_gt_instances: InstanceList, batch_img_metas: List[dict], - cache_targets: bool = False, compute_oks_loss: bool = True) -> Tuple[Tensor]: """Loss function for outputs from a single decoder layer of a single feature level. @@ -399,7 +396,6 @@ def loss_by_feat_single(self, cls_avg_factor = reduce_mean( cls_scores.new_tensor([cls_avg_factor])) cls_avg_factor = max(cls_avg_factor, 1) - loss_cls = self.loss_cls( cls_scores, labels, label_weights, avg_factor=cls_avg_factor) @@ -422,7 +418,7 @@ def loss_by_feat_single(self, pos_kpt_targets = kpt_targets[pos_mask] * factors pos_kpt_weights = kpt_weights[pos_mask] pos_bbox_targets = (bbox_targets[pos_mask].reshape(-1, 2, 2) * - factor).reshape(-1, 4) + factors).flatten(-2) loss_oks = self.loss_oks_aux(pos_kpt_preds, pos_kpt_targets, pos_kpt_weights, pos_bbox_targets) @@ -431,26 +427,6 @@ def loss_by_feat_single(self, return loss_cls, loss_kpt, loss_oks - def loss_heatmap(self, hm_memory, hm_mask, batch_gt_fields): - - # compute heatmap predition - pred_heatmaps = self.heatmap_fc(hm_memory) - pred_heatmaps = torch.clamp( - pred_heatmaps.sigmoid_(), min=1e-4, max=1 - 1e-4) - pred_heatmaps = pred_heatmaps.permute(0, 3, 1, 2).contiguous() - - # construct heatmap target - gt_heatmaps = torch.zeros_like(pred_heatmaps) - for i, gf in enumerate(batch_gt_fields): - gt_heatmap = gf.gt_heatmaps - h = min(gt_heatmap.size(1), gt_heatmaps.size(2)) - w = min(gt_heatmap.size(2), gt_heatmaps.size(3)) - gt_heatmaps[i, :, :h, :w] = gt_heatmap[:, :h, :w] - - loss_hm = self.loss_hm(pred_heatmaps, gt_heatmaps, None, - 1 - hm_mask.unsqueeze(1).float()) - return loss_hm - def loss_refined_kpts(self, kpt_preds: Tensor, batch_img_metas: List[dict]) -> Tuple[Tensor]: """Loss function for outputs from a single decoder layer of a single @@ -517,7 +493,7 @@ def loss_refined_kpts(self, kpt_preds: Tensor, pos_kpt_targets = pos_kpt_targets * factors pos_kpt_weights = kpt_weights[pos_mask] pos_bbox_targets = (bbox_targets[pos_mask].reshape(-1, 2, 2) * - factor).reshape(-1, 4) + factors).flatten(-2) loss_oks = self.loss_oks_aux(pos_kpt_preds, pos_kpt_targets, pos_kpt_weights, pos_bbox_targets) @@ -526,6 +502,26 @@ def loss_refined_kpts(self, kpt_preds: Tensor, return loss_kpt, loss_oks + def loss_heatmap(self, hm_memory, hm_mask, batch_gt_fields): + + # compute heatmap predition + pred_heatmaps = self.heatmap_fc(hm_memory) + pred_heatmaps = torch.clamp( + pred_heatmaps.sigmoid_(), min=1e-4, max=1 - 1e-4) + pred_heatmaps = pred_heatmaps.permute(0, 3, 1, 2).contiguous() + + # construct heatmap target + gt_heatmaps = torch.zeros_like(pred_heatmaps) + for i, gf in enumerate(batch_gt_fields): + gt_heatmap = gf.gt_heatmaps + h = min(gt_heatmap.size(1), gt_heatmaps.size(2)) + w = min(gt_heatmap.size(2), gt_heatmaps.size(3)) + gt_heatmaps[i, :, :h, :w] = gt_heatmap[:, :h, :w] + + loss_hm = self.loss_hm(pred_heatmaps, gt_heatmaps, None, + 1 - hm_mask.unsqueeze(1).float()) + return loss_hm + @torch.no_grad() def get_targets(self, cls_scores_list: List[Tensor], From 54f8f5bd9f417bad56f257c3deda016c3bb918cb Mon Sep 17 00:00:00 2001 From: lupeng Date: Mon, 29 May 2023 17:57:04 +0800 Subject: [PATCH 04/12] update --- .../petr/configs/petr_r50_8xb4-100e_coco.py | 1 + projects/petr/datasets/transforms.py | 62 ++++++++++++++++--- projects/petr/models/petr.py | 8 ++- projects/petr/models/petr_head.py | 16 +++++ 4 files changed, 77 insertions(+), 10 deletions(-) diff --git a/projects/petr/configs/petr_r50_8xb4-100e_coco.py b/projects/petr/configs/petr_r50_8xb4-100e_coco.py index ad52daf802..c53e3965ce 100644 --- a/projects/petr/configs/petr_r50_8xb4-100e_coco.py +++ b/projects/petr/configs/petr_r50_8xb4-100e_coco.py @@ -204,6 +204,7 @@ keep_ratio=True) ] ]), + dict(type='FilterDetPoseAnnotations', keep_empty=False), dict(type='GenerateHeatmap'), dict( type='PackDetPoseInputs', diff --git a/projects/petr/datasets/transforms.py b/projects/petr/datasets/transforms.py index 17a3e7709e..08fb915316 100644 --- a/projects/petr/datasets/transforms.py +++ b/projects/petr/datasets/transforms.py @@ -7,6 +7,8 @@ from mmdet.registry import TRANSFORMS from mmdet.structures.bbox.box_type import autocast_box_type from mmengine.structures import PixelData +from mmdet.datasets.transforms import FilterAnnotations as FilterDetAnnotations + from mmpose.codecs.utils import generate_gaussian_heatmaps from .bbox_keypoint_structure import BBoxKeypoints @@ -85,13 +87,6 @@ def transform(self, results: dict) -> dict: gt_fields.set_field(results[key], packed_key) - # Ensure all keys in `self.meta_keys` are in the `results` dictionary, - # which is necessary for `PackDetInputs` but not guaranteed during - # inference with an inferencer - for key in self.meta_keys: - if key not in results: - results[key] = None - results = super().transform(results) if gt_fields: results['data_samples'].gt_fields = gt_fields.to_tensor() @@ -143,7 +138,7 @@ def _get_instance_wise_sigmas(self, sq3 = np.sqrt(b3**2 - 4 * a3 * c3) r3 = (b3 + sq3) / 2 - sigmas[i] = min(r1, r2, r3, 3) / 3 + sigmas[i] = min(r1, r2, r3) / 3 return sigmas @@ -174,3 +169,54 @@ def transform(self, results: dict) -> Union[dict, None]: results['gt_heatmaps'] = hm return results + + +@TRANSFORMS.register_module() +class FilterDetPoseAnnotations(FilterDetAnnotations): + """Filter invalid annotations. + + In addition to the conditions checked by ``FilterDetAnnotations``, this + filter adds a new condition requiring instances to have at least one + visible keypoints. + """ + + @autocast_box_type() + def transform(self, results: dict) -> Union[dict, None]: + """Transform function to filter annotations. + + Args: + results (dict): Result dict. + + Returns: + dict: Updated result dict. + """ + assert 'gt_bboxes' in results + gt_bboxes = results['gt_bboxes'] + if gt_bboxes.shape[0] == 0: + return results + + tests = [] + if self.by_box: + tests.append(((gt_bboxes.widths > self.min_gt_bbox_wh[0]) & + (gt_bboxes.heights > self.min_gt_bbox_wh[1]) & + (gt_bboxes.num_keypoints > 0)).numpy()) + + if self.by_mask: + assert 'gt_masks' in results + gt_masks = results['gt_masks'] + tests.append(gt_masks.areas >= self.min_gt_mask_area) + + keep = tests[0] + for t in tests[1:]: + keep = keep & t + + if not keep.any(): + if self.keep_empty: + return None + + keys = ('gt_bboxes', 'gt_bboxes_labels', 'gt_masks', 'gt_ignore_flags') + for key in keys: + if key in results: + results[key] = results[key][keep] + + return results diff --git a/projects/petr/models/petr.py b/projects/petr/models/petr.py index c94794fb6d..b2c7ce667b 100644 --- a/projects/petr/models/petr.py +++ b/projects/petr/models/petr.py @@ -13,10 +13,10 @@ from mmdet.registry import MODELS from mmdet.structures import OptSampleList, SampleList from mmengine.model import xavier_init -from torch import Tensor, nn +from torch import Tensor, nnat from torch.nn.init import normal_ -from .transformers import PetrTransformerDecoder +from .transformers import PetrTransformerDecoder, MultiScaleDeformablePoseAttention @MODELS.register_module() @@ -77,12 +77,16 @@ def init_weights(self) -> None: for m in self.modules(): if isinstance(m, MultiScaleDeformableAttention): m.init_weights() + for m in self.modules(): + if isinstance(m, MultiScaleDeformablePoseAttention): + m.init_weights() if self.as_two_stage: nn.init.xavier_uniform_(self.memory_trans_fc.weight) else: xavier_init( self.reference_points_fc, distribution='uniform', bias=0.) normal_(self.level_embed) + normal_(self.kpt_query_embedding.weight) def forward_transformer(self, img_feats: Tuple[Tensor], diff --git a/projects/petr/models/petr_head.py b/projects/petr/models/petr_head.py index d45bdace1a..6c03d15910 100644 --- a/projects/petr/models/petr_head.py +++ b/projects/petr/models/petr_head.py @@ -4,6 +4,8 @@ import torch import torch.nn as nn +from mmengine.model import bias_init_with_prob, constant_init, normal_init + from mmcv.cnn import Linear from mmdet.models import inverse_sigmoid from mmdet.models.dense_heads import DeformableDETRHead @@ -78,6 +80,20 @@ def _init_layers(self) -> None: self.heatmap_fc = Linear(self.embed_dims, self.num_keypoints) + def init_weights(self) -> None: + """Initialize weights of the Deformable DETR head.""" + if self.loss_cls.use_sigmoid: + bias_init = bias_init_with_prob(0.01) + for m in self.cls_branches: + nn.init.constant_(m.bias, bias_init) + for m in self.reg_branches: + constant_init(m[-1], 0, bias=0) + for m in self.kpt_branches: + constant_init(m[-1], 0, bias=0) + # initialize bias for heatmap prediction + bias_init = bias_init_with_prob(0.1) + normal_init(self.heatmap_fc, std=0.01, bias=bias_init) + def forward(self, hidden_states: Tensor, references: List[Tensor]) -> Tuple[Tensor]: """Forward function. From d3af7cd05e170c3cd1582bf830a3568fe9491ebb Mon Sep 17 00:00:00 2001 From: lupeng Date: Wed, 31 May 2023 10:58:31 +0800 Subject: [PATCH 05/12] fix oks loss --- projects/petr/configs/petr_r50_8xb4-100e_coco.py | 6 +++--- projects/petr/models/losses.py | 9 ++++----- projects/petr/models/petr.py | 7 ++++--- 3 files changed, 11 insertions(+), 11 deletions(-) diff --git a/projects/petr/configs/petr_r50_8xb4-100e_coco.py b/projects/petr/configs/petr_r50_8xb4-100e_coco.py index c53e3965ce..60ecc7be4e 100644 --- a/projects/petr/configs/petr_r50_8xb4-100e_coco.py +++ b/projects/petr/configs/petr_r50_8xb4-100e_coco.py @@ -134,8 +134,8 @@ gamma=2.0, alpha=0.25, loss_weight=2.0), - loss_reg=dict(type='L1Loss', loss_weight=40.0), - loss_reg_aux=dict(type='L1Loss', loss_weight=35.0), + loss_reg=dict(type='L1Loss', loss_weight=80.0), + loss_reg_aux=dict(type='L1Loss', loss_weight=70.0), loss_oks=dict( type='OksLoss', metainfo='configs/_base_/datasets/coco.py', @@ -227,7 +227,7 @@ train_dataloader = dict( batch_size=4, - num_workers=2, + num_workers=4, persistent_workers=True, sampler=dict(type='DefaultSampler', shuffle=True), batch_sampler=dict(type='AspectRatioBatchSampler'), diff --git a/projects/petr/models/losses.py b/projects/petr/models/losses.py index 8e1601e51f..881ae39c1c 100644 --- a/projects/petr/models/losses.py +++ b/projects/petr/models/losses.py @@ -48,7 +48,7 @@ def forward(self, target_weights: Tensor, bboxes: Optional[Tensor] = None) -> Tensor: oks = self.compute_oks(output, target, target_weights, bboxes) - loss = 1 - oks + loss = -torch.log(oks.clamp(min=1e-6)) return loss.mean() * self.loss_weight def compute_oks(self, @@ -79,11 +79,10 @@ def compute_oks(self, sigmas = self.sigmas.reshape(*((1, ) * (dist.ndim - 1)), -1) if sigmas.device != dist.device: sigmas = sigmas.to(dist.device) - dist = dist / sigmas + dist = dist / (sigmas * 2) if bboxes is not None: - area = torch.prod( - bboxes[..., 2:] - bboxes[..., :2], dim=-1).pow(0.5) - dist = dist / area.clip(min=1e-8).unsqueeze(-1) + area = torch.prod(bboxes[..., 2:] - bboxes[..., :2], dim=-1) * 0.53 + dist = dist / area.pow(0.5).clip(min=1e-8).unsqueeze(-1) return (torch.exp(-dist.pow(2) / 2) * target_weights).sum( dim=-1) / target_weights.sum(dim=-1).clip(min=1e-8) diff --git a/projects/petr/models/petr.py b/projects/petr/models/petr.py index b2c7ce667b..a570248890 100644 --- a/projects/petr/models/petr.py +++ b/projects/petr/models/petr.py @@ -13,10 +13,11 @@ from mmdet.registry import MODELS from mmdet.structures import OptSampleList, SampleList from mmengine.model import xavier_init -from torch import Tensor, nnat +from torch import Tensor, nn from torch.nn.init import normal_ -from .transformers import PetrTransformerDecoder, MultiScaleDeformablePoseAttention +from .transformers import (MultiScaleDeformablePoseAttention, + PetrTransformerDecoder) @MODELS.register_module() @@ -689,7 +690,7 @@ def _load_state_dict_pre_hook(self, state_dict, prefix, local_meta, *args, The hook will be automatically registered during initialization. """ return - + if 'mmengine_version' in local_meta: return From a19c1c05ed6f279ec582199106038a4dc2ed528d Mon Sep 17 00:00:00 2001 From: lupeng Date: Wed, 31 May 2023 16:54:37 +0800 Subject: [PATCH 06/12] refine code --- projects/petr/README.md | 60 +++++ .../petr/configs/_base_/coco_detection.py | 67 ----- .../petr/configs/_base_/default_runtime.py | 2 +- .../petr/configs/petr_r50_4xb8-100e_coco.py | 26 -- .../configs/petr_r50_4xb8-100e_coco_param1.py | 26 -- .../petr/configs/petr_r50_8xb4-100e_coco.py | 6 +- projects/petr/datasets/coco_dataset.py | 4 +- projects/petr/datasets/transforms.py | 26 +- projects/petr/models/losses.py | 9 +- projects/petr/models/match_costs.py | 56 +++- projects/petr/models/petr.py | 253 ++++++++---------- projects/petr/models/petr_head.py | 164 +++--------- projects/petr/models/transformers.py | 2 + 13 files changed, 284 insertions(+), 417 deletions(-) create mode 100644 projects/petr/README.md delete mode 100644 projects/petr/configs/_base_/coco_detection.py delete mode 100644 projects/petr/configs/petr_r50_4xb8-100e_coco.py delete mode 100644 projects/petr/configs/petr_r50_4xb8-100e_coco_param1.py diff --git a/projects/petr/README.md b/projects/petr/README.md new file mode 100644 index 0000000000..fcece2a17f --- /dev/null +++ b/projects/petr/README.md @@ -0,0 +1,60 @@ +# YOLOX-Pose + +This project implements PETR (Pose Estimation with TRansformers), an end-to-end multi-person pose estimation framework introduced in the CVPR 2022 paper **End-to-End Multi-Person Pose Estimation with Transformers**. PETR is a novel, end-to-end multi-person pose estimation method that treats pose estimation as a hierarchical set prediction problem. By leveraging attention mechanisms, PETR can adaptively focus on features most relevant to target keypoints, thereby overcoming feature misalignment issues in pose estimation. + +
+ +## Usage + +### Prerequisites + +- Python 3.7 or higher +- PyTorch 1.6 or higher +- [MMEngine](https://github.com/open-mmlab/mmengine) v0.7.0 or higher +- [MMCV](https://github.com/open-mmlab/mmcv) v2.0.0 or higher +- [MMDetection](https://github.com/open-mmlab/mmdetection) v3.0.0 or higher +- [MMPose](https://github.com/open-mmlab/mmpose) v1.0.0 or higher + +All the commands below rely on the correct configuration of `PYTHONPATH`, which should point to the project's directory so that Python can locate the module files. **In `petr/` root directory**, run the following line to add the current directory to `PYTHONPATH`: + +```shell +export PYTHONPATH=`pwd`:$PYTHONPATH +``` + +### Inference + +Users can apply YOLOX-Pose models to estimate human poses using the inferencer found in the MMPose core package. Use the command below: + +```shell +python demo/inferencer_demo.py $INPUTS \ + --pose2d $CONFIG --pose2d-weights $CHECKPOINT --scope mmdet \ + [--show] [--vis-out-dir $VIS_OUT_DIR] [--pred-out-dir $PRED_OUT_DIR] +``` + +For more information on using the inferencer, please see [this document](https://mmpose.readthedocs.io/en/latest/user_guides/inference.html#out-of-the-box-inferencer). + +### Results + +Results on COCO val2017 + +| Model | Backbone | Lr schd | mAP | AP50 | AP75 | AR | AR50 | Config | Download | +| :---: | :------: | :-----: | :--: | :-------------: | :-------------: | :--: | :-------------: | :----------------------------------------------: | :-----------------------------------------------------------------------------: | +| PETR | R-50 | 100e | 68.7 | 87.5 | 76.2 | 75.9 | 92.1 | [config](/configs/petr_r50_8xb4-100e_coco.py) | [Google Drive](https://drive.google.com/file/d/1HcwraqWdZ3CaGMQOJHY8exNem7UnFkfS/view?usp=sharing) \| [BaiduYun](https://pan.baidu.com/s/1C0HbQWV7K-GHQE7q34nUZw?pwd=u798) | +| PETR | R-101 | 100e | 70.0 | 88.5 | 77.5 | 77.0 | 92.6 | [config](/configs/petr_r101_8xb4-100e_coco.py) | [Google Drive](https://drive.google.com/file/d/1O261Jrt4JRGlIKTmLtPy3AUruwX1hsDf/view?usp=sharing) \| [BaiduYun](https://pan.baidu.com/s/1D5wqNP53KNOKKE5NnO2Dnw?pwd=keyn) | +| PETR | Swin-L | 100e | 73.0 | 90.7 | 80.9 | 80.1 | 94.5 | [config](/configs/petr_swin-l_8xb4-100e_coco.py) | [Google Drive](https://drive.google.com/file/d/1ujL0Gm5tPjweT0-gdDGkTc7xXrEt6gBP/view?usp=sharing) \| [BaiduYun](https://pan.baidu.com/s/1X5Cdq75GosRCKqbHZTSpJQ?pwd=t9ea) | + +Currently, the PETR implemented in this project supports inference using the official checkpoint. However, the training accuracy is still not up to the results reported in the paper. We will continue to update this project after aligning the training accuracy. + +## Citation + +If this project benefits your work, please kindly consider citing the original papers: + +```bibtex +@inproceedings{shi2022end, + title={End-to-end multi-person pose estimation with transformers}, + author={Shi, Dahu and Wei, Xing and Li, Liangqi and Ren, Ye and Tan, Wenming}, + booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, + pages={11069--11078}, + year={2022} +} +``` diff --git a/projects/petr/configs/_base_/coco_detection.py b/projects/petr/configs/_base_/coco_detection.py deleted file mode 100644 index 1761a0a3cb..0000000000 --- a/projects/petr/configs/_base_/coco_detection.py +++ /dev/null @@ -1,67 +0,0 @@ -# dataset settings -dataset_type = 'mmpose.CocoDataset' -data_mode = 'bottomup' -data_root = 'data/coco/' - -# file_client_args = dict( -# backend='petrel', -# path_mapping=dict({ -# './data/': 's3://openmmlab/datasets/detection/', -# 'data/': 's3://openmmlab/datasets/detection/' -# })) -file_client_args = dict(backend='disk') - -train_pipeline = [ - dict(type='LoadImageFromFile', file_client_args=file_client_args), - dict(type='LoadAnnotations', with_bbox=True), - dict(type='Resize', scale=(1333, 800), keep_ratio=True), - dict(type='RandomFlip', prob=0.5), - dict(type='PackDetInputs') -] - -test_pipeline = [ - dict(type='LoadImageFromFile', file_client_args=file_client_args), - dict(type='Resize', scale=(1333, 800), keep_ratio=True), - # If you don't have a gt annotation, delete the pipeline - dict(type='LoadAnnotations', with_bbox=True), - dict( - type='PackDetInputs', - meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', - 'scale_factor')) -] -train_dataloader = dict( - batch_size=2, - num_workers=2, - persistent_workers=True, - sampler=dict(type='DefaultSampler', shuffle=True), - batch_sampler=dict(type='AspectRatioBatchSampler'), - dataset=dict( - type=dataset_type, - data_root=data_root, - ann_file='annotations/person_keypoints_train2017.json', - data_prefix=dict(img='train2017/'), - filter_cfg=dict(filter_empty_gt=True, min_size=32), - pipeline=train_pipeline)) - -val_dataloader = dict( - batch_size=1, - num_workers=2, - persistent_workers=True, - drop_last=False, - sampler=dict(type='DefaultSampler', shuffle=False), - dataset=dict( - type=dataset_type, - data_root=data_root, - ann_file='annotations/person_keypoints_val2017.json', - data_prefix=dict(img='val2017/'), - test_mode=True, - pipeline=test_pipeline)) -test_dataloader = val_dataloader - -val_evaluator = dict( - type='CocoMetric', - ann_file=data_root + 'annotations/person_keypoints_val2017.json', -) -test_evaluator = val_evaluator - -default_hooks = dict(checkpoint=dict(save_best='coco/AP', rule='greater')) diff --git a/projects/petr/configs/_base_/default_runtime.py b/projects/petr/configs/_base_/default_runtime.py index 7a2a84af27..fb9e38577c 100644 --- a/projects/petr/configs/_base_/default_runtime.py +++ b/projects/petr/configs/_base_/default_runtime.py @@ -33,7 +33,7 @@ resume = False # file I/O backend -file_client_args = dict(backend='disk') +backend_args = dict(backend='local') # training/validation/testing progress train_cfg = dict() diff --git a/projects/petr/configs/petr_r50_4xb8-100e_coco.py b/projects/petr/configs/petr_r50_4xb8-100e_coco.py deleted file mode 100644 index 5c5c22d4bb..0000000000 --- a/projects/petr/configs/petr_r50_4xb8-100e_coco.py +++ /dev/null @@ -1,26 +0,0 @@ -_base_ = ['./petr_r50_8xb4-100e_coco.py'] - -model = dict( - bbox_head=dict( - loss_cls=dict(loss_weight=2.0), - loss_reg=dict(type='L1Loss', loss_weight=80.0), - loss_reg_aux=dict(type='L1Loss', loss_weight=70.0), - loss_oks=dict(type='OksLoss', loss_weight=30.0), - loss_oks_aux=dict(type='OksLoss', loss_weight=20.0), - loss_hm=dict(type='mmpose.FocalHeatmapLoss', loss_weight=4.0), - ), - # training and testing settings - train_cfg=dict( - assigner=dict( - type='HungarianAssigner', - match_costs=[ - dict(type='FocalLossCost', weight=2.0), - dict(type='KptL1Cost', weight=70.0), - dict( - type='OksCost', - metainfo='configs/_base_/datasets/coco.py', - weight=70.0) - ]))) - - -train_dataloader = dict(batch_size=8) \ No newline at end of file diff --git a/projects/petr/configs/petr_r50_4xb8-100e_coco_param1.py b/projects/petr/configs/petr_r50_4xb8-100e_coco_param1.py deleted file mode 100644 index c496edde5e..0000000000 --- a/projects/petr/configs/petr_r50_4xb8-100e_coco_param1.py +++ /dev/null @@ -1,26 +0,0 @@ -_base_ = ['./petr_r50_8xb4-100e_coco.py'] - -model = dict( - bbox_head=dict( - loss_cls=dict(loss_weight=2.0), - loss_reg=dict(type='L1Loss', loss_weight=8.0), - loss_reg_aux=dict(type='L1Loss', loss_weight=7.0), - loss_oks=dict(type='OksLoss', loss_weight=3.0), - loss_oks_aux=dict(type='OksLoss', loss_weight=2.0), - loss_hm=dict(type='mmpose.FocalHeatmapLoss', loss_weight=2.0), - ), - # training and testing settings - train_cfg=dict( - assigner=dict( - type='HungarianAssigner', - match_costs=[ - dict(type='FocalLossCost', weight=2.0), - dict(type='KptL1Cost', weight=7.0), - dict( - type='OksCost', - metainfo='configs/_base_/datasets/coco.py', - weight=7.0) - ]))) - - -train_dataloader = dict(batch_size=8) \ No newline at end of file diff --git a/projects/petr/configs/petr_r50_8xb4-100e_coco.py b/projects/petr/configs/petr_r50_8xb4-100e_coco.py index 60ecc7be4e..72ee7dbd78 100644 --- a/projects/petr/configs/petr_r50_8xb4-100e_coco.py +++ b/projects/petr/configs/petr_r50_8xb4-100e_coco.py @@ -18,8 +18,6 @@ ] # NOTE: `auto_scale_lr` is for automatically scaling LR, -# USER SHOULD NOT CHANGE ITS VALUES. -# base_batch_size = (16 GPUs) x (2 samples per GPU) auto_scale_lr = dict(base_batch_size=32) # optimizer @@ -164,7 +162,7 @@ )) train_pipeline = [ - dict(type='LoadImageFromFile', file_client_args=_base_.file_client_args), + dict(type='mmpose.LoadImage', backend_args=_base_.backend_args), dict(type='PoseToDetConverter'), dict(type='PhotoMetricDistortion'), dict( @@ -212,7 +210,7 @@ ] test_pipeline = [ - dict(type='LoadImageFromFile'), + dict(type='mmpose.LoadImage', backend_args=_base_.backend_args), dict(type='PoseToDetConverter'), dict(type='Resize', scale=(1333, 800), keep_ratio=True), dict( diff --git a/projects/petr/datasets/coco_dataset.py b/projects/petr/datasets/coco_dataset.py index 99e0d1120f..165aa2f251 100644 --- a/projects/petr/datasets/coco_dataset.py +++ b/projects/petr/datasets/coco_dataset.py @@ -11,7 +11,8 @@ class CocoDataset(MMPoseCocoDataset): def parse_data_info(self, raw_data_info: dict) -> Optional[dict]: - """Parse raw COCO annotation of an instance. + """Parse raw COCO annotation of an instance. Compared with original + implementation, this method add image width and height in `data_info` Args: raw_data_info (dict): Raw data information loaded from @@ -57,6 +58,7 @@ def parse_data_info(self, raw_data_info: dict) -> Optional[dict]: data_info = { 'img_id': ann['image_id'], 'img_path': img['img_path'], + # Addition: add image width and height 'width': img_w, 'height': img_h, 'bbox': bbox, diff --git a/projects/petr/datasets/transforms.py b/projects/petr/datasets/transforms.py index 08fb915316..be4c748505 100644 --- a/projects/petr/datasets/transforms.py +++ b/projects/petr/datasets/transforms.py @@ -3,12 +3,11 @@ import numpy as np from mmcv.transforms import BaseTransform +from mmdet.datasets.transforms import FilterAnnotations as FilterDetAnnotations from mmdet.datasets.transforms import PackDetInputs from mmdet.registry import TRANSFORMS from mmdet.structures.bbox.box_type import autocast_box_type from mmengine.structures import PixelData -from mmdet.datasets.transforms import FilterAnnotations as FilterDetAnnotations - from mmpose.codecs.utils import generate_gaussian_heatmaps from .bbox_keypoint_structure import BBoxKeypoints @@ -94,17 +93,26 @@ def transform(self, results: dict) -> dict: return results -@TRANSFORMS.register_module(force=True) +@TRANSFORMS.register_module() class GenerateHeatmap(BaseTransform): + """This class is responsible for creating a heatmap based on bounding boxes + and keypoints. + + It generates a Gaussian heatmap for each instance in the image, given the + bounding box and keypoints. + """ def _get_instance_wise_sigmas(self, bbox: np.ndarray, heatmap_min_overlap: float = 0.9 ) -> np.ndarray: - """Get sigma values for each instance according to their size. + """Compute sigma values for each instance based on their bounding box + size. Args: bbox (np.ndarray): Bounding box in shape (N, 4, 2) + heatmap_min_overlap (float, optional): Minimum overlap for + the heatmap. Defaults to 0.9. Returns: np.ndarray: Array containing the sigma values for each instance. @@ -144,14 +152,12 @@ def _get_instance_wise_sigmas(self, @autocast_box_type() def transform(self, results: dict) -> Union[dict, None]: - """Transform function to filter annotations. - - Args: - results (dict): Result dict. + """Apply the transformation to filter annotations. - Returns: - dict: Updated result dict. + This function rescales bounding boxes and keypoints, and generates a + Gaussian heatmap for each instance. """ + assert 'gt_bboxes' in results bbox = results['gt_bboxes'].tensor.numpy() / 8 diff --git a/projects/petr/models/losses.py b/projects/petr/models/losses.py index 881ae39c1c..279e6f7c6f 100644 --- a/projects/petr/models/losses.py +++ b/projects/petr/models/losses.py @@ -12,9 +12,8 @@ @MODELS.register_module(force=True) class OksLoss(nn.Module): """A PyTorch implementation of the Object Keypoint Similarity (OKS) loss as - described in the paper "YOLO-Pose: Enhancing YOLO for Multi Person Pose - Estimation Using Object Keypoint Similarity Loss" by Debapriya et al. - (2022). + described in the paper "End-to-End Multi-Person Pose Estimation with + Transformers" by Shi et al. (2022). The OKS loss is used for keypoint-based object recognition and consists of a measure of the similarity between predicted and ground truth @@ -56,7 +55,7 @@ def compute_oks(self, target: Tensor, target_weights: Tensor, bboxes: Optional[Tensor] = None) -> Tensor: - """Calculates the OKS loss. + """Calculates the OKS metric. Args: output (Tensor): Predicted keypoints in shape N x k x 2, where N @@ -70,7 +69,7 @@ def compute_oks(self, where 4 are the xyxy coordinates. Returns: - Tensor: The calculated OKS loss. + Tensor: The calculated OKS. """ dist = torch.norm(output - target, dim=-1) diff --git a/projects/petr/models/match_costs.py b/projects/petr/models/match_costs.py index 18b350d0b1..2e17bb4087 100644 --- a/projects/petr/models/match_costs.py +++ b/projects/petr/models/match_costs.py @@ -1,7 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. from typing import Optional -import torch from mmdet.models.task_modules.assigners.match_cost import BaseMatchCost from mmdet.registry import TASK_UTILS from mmengine.structures import InstanceData @@ -12,33 +11,63 @@ @TASK_UTILS.register_module() class KptL1Cost(BaseMatchCost): + """This class computes the L1 cost between predicted and ground truth + keypoints. + + The cost is computed based on the normalized difference between the + keypoints. The keypoints visibility is also taken into account while + calculating the cost. + """ def __call__(self, pred_instances: InstanceData, gt_instances: InstanceData, img_meta: Optional[dict] = None, **kwargs) -> Tensor: + """Compute the L1 cost between predicted and ground truth keypoints. + + Args: + pred_instances (InstanceData): Predicted instances data. + gt_instances (InstanceData): Ground truth instances data. + img_meta (dict, optional): Meta data of the image. Defaults + to None. + + Returns: + Tensor: L1 cost between predicted and ground truth keypoints. + """ + # Extract keypoints from predicted and ground truth instances pred_keypoints = pred_instances.keypoints gt_keypoints = gt_instances.keypoints + + # Get the visibility of keypoints and normalize it gt_keypoints_visible = gt_instances.keypoints_visible - gt_keypoints_visible = gt_keypoints_visible / (2 * gt_keypoints_visible.sum(dim=1, keepdim=True) + 1e-8) - - # normalized + gt_keypoints_visible = gt_keypoints_visible / ( + 2 * gt_keypoints_visible.sum(dim=1, keepdim=True) + 1e-8) + + # Normalize keypoints based on image shape img_h, img_w = img_meta['img_shape'] factor = gt_keypoints.new_tensor([img_w, img_h]).reshape(1, 1, 2) gt_keypoints = (gt_keypoints / factor).unsqueeze(0) gt_keypoints_visible = gt_keypoints_visible.unsqueeze(0).unsqueeze(-1) pred_keypoints = (pred_keypoints / factor).unsqueeze(1) - + + # Calculate L1 cost considering visibility of keypoints diff = (pred_keypoints - gt_keypoints) * gt_keypoints_visible kpt_cost = diff.flatten(2).norm(dim=2, p=1) - + return kpt_cost * self.weight @TASK_UTILS.register_module() class OksCost(BaseMatchCost, OksLoss): + """This class computes the OKS (Object Keypoint Similarity) cost between + predicted and ground truth keypoints. + + It normalizes keypoints based on image shape, then calculates the OKS using + a method from the OksLoss class. It also includes visibility and bounding + box information in the calculation. + """ def __init__(self, metainfo: Optional[str] = None, weight: float = 1.0): OksLoss.__init__(self, metainfo, weight) @@ -49,19 +78,32 @@ def __call__(self, gt_instances: InstanceData, img_meta: Optional[dict] = None, **kwargs) -> Tensor: + """Compute the OKS cost between predicted and ground truth keypoints. + + Args: + pred_instances (InstanceData): Predicted instances data. + gt_instances (InstanceData): Ground truth instances data. + img_meta (dict, optional): Meta data of the image. Defaults + to None. + + Returns: + Tensor: OKS cost between predicted and ground truth keypoints. + """ + # Extract keypoints and bounding boxes pred_keypoints = pred_instances.keypoints gt_keypoints = gt_instances.keypoints gt_bboxes = gt_instances.bboxes gt_keypoints_visible = gt_instances.keypoints_visible - # normalized + # Normalize keypoints and bounding boxes based on image shape img_h, img_w = img_meta['img_shape'] factor = gt_keypoints.new_tensor([img_w, img_h]).reshape(1, 1, 2) gt_keypoints = (gt_keypoints / factor).unsqueeze(0) pred_keypoints = (pred_keypoints / factor).unsqueeze(1) gt_bboxes = (gt_bboxes.reshape(-1, 2, 2) / factor).reshape(1, -1, 4) + # Calculate OKS cost kpt_cost = self.compute_oks(pred_keypoints, gt_keypoints, gt_keypoints_visible, gt_bboxes) kpt_cost = -kpt_cost diff --git a/projects/petr/models/petr.py b/projects/petr/models/petr.py index a570248890..1a9673dfb2 100644 --- a/projects/petr/models/petr.py +++ b/projects/petr/models/petr.py @@ -22,6 +22,22 @@ @MODELS.register_module() class PETR(DeformableDETR): + r"""Implementation of `End-to-End Multi-Person Pose Estimation with + Transformers `_ + + Code is modified from the `official github repo + `_. + + Args: + num_keypoints (int): Numbder of Keypoints. Defaults to 17. + hm_encoder (:obj:`ConfigDict` or dict, optional): Config of the + heatmap encoder. Defaults to None. + kpt_decoder (:obj:`ConfigDict` or dict, optional): Config for the + keypoint refine decoder. Defaults to None. + """ + _version = 2 def __init__(self, num_keypoints: int = 17, @@ -93,37 +109,58 @@ def forward_transformer(self, img_feats: Tuple[Tensor], batch_data_samples: OptSampleList = None, test_mode: bool = True) -> Dict: - """Forward process of Transformer, which includes four steps: - 'pre_transformer' -> 'encoder' -> 'pre_decoder' -> 'decoder'. We - summarized the parameters flow of the existing DETR-like detector, - which can be illustrated as follow: + """Forward process of Transformer in PETR. + + This function consists of seven stages: 'pre_transformer', + 'forward_encoder', 'pre_decoder', 'forward_decoder', + 'pre_kpt_decoder', 'forward_kpt_decoder', and 'forward_kpt_head'. + It takes image features (img_feats) and batch data samples + (batch_data_samples) as inputs and performs transformations at + each stage. The output is a dictionary of inputs to be used for the + bounding box head function (bbox_head). .. code:: text - img_feats & batch_data_samples - | - V - +-----------------+ - | pre_transformer | - +-----------------+ - | | - | V - | +-----------------+ - | | forward_encoder | - | +-----------------+ - | | - | V - | +---------------+ - | | pre_decoder | - | +---------------+ - | | | - V V | - +-----------------+ | - | forward_decoder | | - +-----------------+ | - | | - V V - head_inputs_dict + img_feats & batch_data_samples + | + V + +-----------------+ + | pre_transformer | + +-----------------+ + | | + | V + | +-----------------+ + | | forward_encoder | + | +-----------------+ + | | + | V + | +---------------+ + | | pre_decoder | + | +---------------+ + | | | + V V | + +-----------------+ | + | forward_decoder | | + +-----------------+ | + | | + V V + +-----------------+ | + | pre_kpt_decoder | | + +-----------------+ | + | | + V V + +--------------------+ | + | forward_kpt_decoder| | + +--------------------+ | + | | + V V + +----------------+ | + |forward_kpt_head| | + +----------------+ | + | | + V V + head_inputs_dict + Args: img_feats (tuple[Tensor]): Tuple of feature maps from neck. Each @@ -132,96 +169,44 @@ def forward_transformer(self, batch data samples. It usually includes information such as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`. Defaults to None. + test_mode (bool, optional): If True, the function operates in test + mode. Defaults to True. Returns: - dict: The dictionary of bbox_head function inputs, which always - includes the `hidden_states` of the decoder output and may contain - `references` including the initial and intermediate references. + head_inputs_dict (dict): The dictionary of bbox_head function + inputs. Always includes 'hidden_states' from the decoder output + and may contain 'references' including the initial and + intermediate references. The specific contents of this dict + differ based on whether the function is operating in test_mode + or not. In test_mode, 'det_labels' and 'det_scores' are + included. In training mode, it includes additional elements + such as 'enc_outputs_class', 'enc_outputs_coord', + 'all_layers_classes', 'all_layers_coords', 'hm_memory', + and 'hm_mask'. """ encoder_inputs_dict, decoder_inputs_dict = self.pre_transformer( img_feats, batch_data_samples) - # encoder_inputs_dict - # |- feat: [bs, mlv_shape, 256] - # |- feat_mask: [bs, mlv_shape] - # |- feat_pos: [bs, mlv_shape, 256] - # |- spatial_shapes: [4, 2] - # |- level_start_index: [4] - # |- valid_ratios: [bs, 4, 2] - # decoder_inputs_dict - # |- memory_mask: [bs, mlv_shape] - # |- spatial_shapes [4, 2] - # |- level_start_index [4] - # |- valid_ratios [bs, 4, 2] encoder_outputs_dict, heatmap_dict = self.forward_encoder( **encoder_inputs_dict, test_mode=test_mode) - # encoder_outputs_dict - # |- memory: [bs, mlv_shape, 256] - # |- memory_mask: [bs, mlv_shape] (feat_mask) - # |- spatial_shapes: [4, 2] - # heatmap_dict - # |- hm_memory: [bs, lv0_h, lv0_w, 256] - # |- hm_mask: [bs, lv0_h, lv0_w] tmp_dec_in, head_inputs_dict = self.pre_decoder(**encoder_outputs_dict) - # tmp_dec_in - # |- query: [bs, num_queries, 256] - # |- query_pos: [bs, num_queries, 256] - # |- memory: [bs, mlv_shape, 256] (memory) - # |- reference_points: [bs, num_queries, 2*num_keypoints] - # head_inputs_dict (train only) - # |- enc_outputs_class: [bs, mlv_shape, 1] - # |- enc_outputs_coord: [bs, mlv_shape, 34] decoder_inputs_dict.update(tmp_dec_in) - # decoder_inputs_dict - # |- query: [bs, num_queries, 256] - # |- query_pos: [bs, num_queries, 256] - # |- memory: [bs, mlv_shape, 256] (memory) - # |- memory_mask: [bs, mlv_shape] - # |- spatial_shapes [4, 2] - # |- level_start_index [4] - # |- valid_ratios [bs, 4, 2] decoder_outputs_dict = self.forward_decoder(**decoder_inputs_dict) - # decoder_outputs_dict - # |- hidden_states [3, bs, num_queries, 256] - # |- references [1, 300, 34] * 4 - # |- all_layers_classes [3, bs, num_queries, 1] - # |- all_layers_coords [3, bs, num_queries, 2*num_keypoints] kpt_decoder_inputs_dict = self.pre_kpt_decoder( **decoder_outputs_dict, batch_data_samples=batch_data_samples, test_mode=test_mode) - # kpt_decoder_inputs_dict - # |- pos_kpt_coords: [max_inst, 2*num_keypoints] - # |- pos_img_inds: [max_inst] - # |- det_labels: [max_inst] (test only) - # |- det_scores: [max_inst] (test only) kpt_decoder_inputs_dict.update(decoder_inputs_dict) - # kpt_decoder_inputs_dict - # |- pos_kpt_coords: [max_inst, 2*num_keypoints] - # |- pos_img_inds: [max_inst] - # |- det_labels: [max_inst] - # |- query: [bs, num_queries, 256] - # |- query_pos: [bs, num_queries, 256] - # |- memory: [bs, mlv_shape, 256] (memory) - # |- memory_mask: [bs, mlv_shape] - # |- spatial_shapes [4, 2] - # |- level_start_index [4] - # |- valid_ratios [bs, 4, 2] kpt_decoder_outputs_dict = self.forward_kpt_decoder( **kpt_decoder_inputs_dict) - # kpt_decoder_outputs_dict (test) - # |- inter_states: [2, max_inst, num_keypoints, 256] - # |- reference_points: [max_inst, num_keypoints, 2] - # |- inter_references: [2, max_inst, num_keypoints, 2] dec_outputs_coord = self.forward_kpt_head(**kpt_decoder_outputs_dict) - # dec_outputs_coord: [2, max_inst, num_keypoints, 2] head_inputs_dict['dec_outputs_coord'] = dec_outputs_coord if test_mode: @@ -235,16 +220,6 @@ def forward_transformer(self, 'all_layers_classes'] head_inputs_dict['all_layers_coords'] = decoder_outputs_dict[ 'all_layers_coords'] - # head_inputs_dict - # |- enc_outputs_class: [bs, mlv_shape, 1] (train only) - # |- enc_outputs_coord: [bs, mlv_shape, 34] (train only) - # |- all_layers_classes [3, bs, num_queries, 1] (train only) - # |- all_layers_coords [3, bs, num_queries, 2*num_keypoints] (to) - # |- hm_memory: [bs, lv0_h, lv0_w, 256] (train only) - # |- hm_mask: [bs, lv0_h, lv0_w] (train only) - # |- dec_outputs_coord: [2, max_inst, num_keypoints, 2] - # |- det_labels: [max_inst] (test only) - # |- det_scores: [max_inst] (test only) return head_inputs_dict @@ -262,23 +237,10 @@ def loss(self, batch_inputs: Tensor, Returns: dict: A dictionary of loss components """ - # torch.save(dict( - # batch_inputs=batch_inputs.cpu(), - # batch_data_samples=batch_data_samples - # ), 'notebooks/train_proc_tensors/img+ds.pth') - # exit(0) img_feats = self.extract_feat(batch_inputs) head_inputs_dict = self.forward_transformer( img_feats, batch_data_samples, test_mode=False) - # head_inputs_dict - # |- enc_outputs_class: [bs, mlv_shape, 1] - # |- enc_outputs_coord: [bs, mlv_shape, 34] - # |- all_layers_classes [3, bs, num_queries, 1] - # |- all_layers_coords [3, bs, num_queries, 2*num_keypoints] - # |- hm_memory: [bs, lv0_h, lv0_w, 256] - # |- hm_mask: [bs, lv0_h, lv0_w] - # |- dec_outputs_coord: [2, max_inst, num_keypoints, 2] losses = self.bbox_head.loss( **head_inputs_dict, batch_data_samples=batch_data_samples) @@ -316,10 +278,6 @@ def predict(self, img_feats = self.extract_feat(batch_inputs) head_inputs_dict = self.forward_transformer( img_feats, batch_data_samples, test_mode=True) - # head_inputs_dict - # |- dec_outputs_coord: [2, max_inst, num_keypoints, 2] - # |- det_labels: [max_inst] - # |- det_scores: [max_inst] results_list = self.bbox_head.predict( **head_inputs_dict, @@ -339,11 +297,6 @@ def forward_encoder(self, test_mode: bool = True) -> Dict: """Forward with Transformer encoder. - The forward procedure of the transformer is defined as: - 'pre_transformer' -> 'encoder' -> 'pre_decoder' -> 'decoder' - More details can be found at `TransformerDetector.forward_transformer` - in `mmdet/detector/base_detr.py`. - Args: feat (Tensor): Sequential features, has shape (bs, num_feat_points, dim). @@ -359,6 +312,8 @@ def forward_encoder(self, valid_ratios (Tensor): The ratios of the valid width and the valid height relative to the width and the height of features in all levels, has shape (bs, num_levels, 2). + test_mode (bool, optional): If True, the function operates in test + mode. Defaults to True. Returns: dict: The dictionary of encoder outputs, which includes the @@ -407,11 +362,6 @@ def pre_decoder(self, memory: Tensor, memory_mask: Tensor, """Prepare intermediate variables before entering Transformer decoder, such as `query`, `query_pos`, and `reference_points`. - The forward procedure of the transformer is defined as: - 'pre_transformer' -> 'encoder' -> 'pre_decoder' -> 'decoder' - More details can be found at `TransformerDetector.forward_transformer` - in `mmdet/detector/base_detr.py`. - Args: memory (Tensor): The output embeddings of the Transformer encoder, has shape (bs, num_feat_points, dim). @@ -456,14 +406,6 @@ def pre_decoder(self, memory: Tensor, memory_mask: Tensor, enc_outputs_coord_unact[..., 0::2] += output_proposals[..., 0:1] enc_outputs_coord_unact[..., 1::2] += output_proposals[..., 1:2] enc_outputs_coord = enc_outputs_coord_unact.sigmoid() - # We only use the first channel in enc_outputs_class as foreground, - # the other (num_classes - 1) channels are actually not used. - # Its targets are set to be 0s, which indicates the first - # class (foreground) because we use [0, num_classes - 1] to - # indicate class labels, background class is indicated by - # num_classes (similar convention in RPN). - # See https://github.com/open-mmlab/mmdetection/blob/master/mmdet/models/dense_heads/deformable_detr_head.py#L241 # noqa - # This follows the official implementation of Deformable DETR. topk_proposals = torch.topk( enc_outputs_class[..., 0], self.num_queries, dim=1)[1] topk_coords_unact = torch.gather( @@ -555,7 +497,15 @@ def pre_kpt_decoder(self, batch_data_samples, test_mode=False, **kwargs): + """Prepares the inputs for the keypoint decoder. + Args: + all_layers_classes (Tensor): Classification scores of all layers + all_layers_coords (Tensor): Coordinates of keypoints of all layers + batch_data_samples (list): List of samples in a batch + test_mode (bool, optional): If True, the function will run in test + mode. Defaults to False. + """ cls_scores = all_layers_classes[-1] kpt_coords = all_layers_coords[-1] @@ -629,6 +579,18 @@ def pre_kpt_decoder(self, def forward_kpt_decoder(self, memory, memory_mask, pos_kpt_coords, pos_img_inds, spatial_shapes, level_start_index, valid_ratios, **kwargs): + """Runs the keypoint decoder forward pass. + + Args: + memory (Tensor): The output embeddings from the Transformer + encoder. + memory_mask (Tensor): The mask of the memory. + pos_kpt_coords (Tensor): Positive keypoint coordinates. + pos_img_inds (Tensor): Image indices of positive keypoints. + spatial_shapes (Tensor): Spatial shapes of features. + level_start_index (Tensor): Start index of each level. + valid_ratios (Tensor): Valid ratios of all images. + """ kpt_query_embedding = self.kpt_query_embedding.weight query_pos, query = torch.split( @@ -666,6 +628,17 @@ def forward_kpt_decoder(self, memory, memory_mask, pos_kpt_coords, def forward_kpt_head(self, inter_states, reference_points, inter_references): + """Runs the keypoint head forward pass. + + Args: + inter_states (Tensor): Intermediate states from the keypoint + decoder. + reference_points (Tensor): Reference points from the keypoint + decoder. + inter_references (Tensor): Intermediate reference points from + the keypoint decoder. + """ + outputs_kpts = [] for lvl in range(inter_states.shape[0]): if lvl == 0: @@ -689,12 +662,8 @@ def _load_state_dict_pre_hook(self, state_dict, prefix, local_meta, *args, The hook will be automatically registered during initialization. """ - return - - if 'mmengine_version' in local_meta: - return - if local_meta.get('mmdet_verion', '0') > '3': + if local_meta.get('version', self._version) >= self._version: return mappings = OrderedDict() diff --git a/projects/petr/models/petr_head.py b/projects/petr/models/petr_head.py index 6c03d15910..ac4ecd3c5c 100644 --- a/projects/petr/models/petr_head.py +++ b/projects/petr/models/petr_head.py @@ -4,8 +4,6 @@ import torch import torch.nn as nn -from mmengine.model import bias_init_with_prob, constant_init, normal_init - from mmcv.cnn import Linear from mmdet.models import inverse_sigmoid from mmdet.models.dense_heads import DeformableDETRHead @@ -13,12 +11,37 @@ from mmdet.registry import MODELS from mmdet.structures import SampleList from mmdet.utils import InstanceList, reduce_mean +from mmengine.model import bias_init_with_prob, constant_init, normal_init from mmengine.structures import InstanceData from torch import Tensor @MODELS.register_module() class PETRHead(DeformableDETRHead): + r"""Head of PETR: End-to-End Multi-Person Pose Estimation + with Transformers. + + Code is modified from the `official github repo + `_. + + More details can be found in the `paper + `_ . + + Args: + num_keypoints (int): Number of keypoints. Defaults to 17. + num_pred_kpt_layer (int): The number of the keypoint refine decoder + layers. Defaults to 2. + loss_reg (dict): The configuration dict of regression loss for + outputs of refine decoders. + loss_reg_aux (dict): The configuration dict of regression loss for + outputs of decoders. + loss_oks (dict): The configuration dict of oks loss for outputs + of refine decoders. + loss_oks_aux (dict): The configuration dict of oks loss for + outputs of decoders. + loss_hm (dict): The configuration dict of heatmap loss. + """ def __init__(self, num_keypoints: int = 17, @@ -81,7 +104,7 @@ def _init_layers(self) -> None: self.heatmap_fc = Linear(self.embed_dims, self.num_keypoints) def init_weights(self) -> None: - """Initialize weights of the Deformable DETR head.""" + """Initialize weights of the PETR head.""" if self.loss_cls.use_sigmoid: bias_init = bias_init_with_prob(0.01) for m in self.cls_branches: @@ -96,33 +119,7 @@ def init_weights(self) -> None: def forward(self, hidden_states: Tensor, references: List[Tensor]) -> Tuple[Tensor]: - """Forward function. - - Args: - hidden_states (Tensor): Hidden states output from each decoder - layer, has shape (num_decoder_layers, bs, num_queries, dim). - references (list[Tensor]): List of the reference from the decoder. - The first reference is the `init_reference` (initial) and the - other num_decoder_layers(6) references are `inter_references` - (intermediate). The `init_reference` has shape (bs, - num_queries, 4) when `as_two_stage` of the detector is `True`, - otherwise (bs, num_queries, 2). Each `inter_reference` has - shape (bs, num_queries, 4) when `with_box_refine` of the - detector is `True`, otherwise (bs, num_queries, 2). The - coordinates are arranged as (cx, cy) when the last dimension is - 2, and (cx, cy, w, h) when it is 4. - - Returns: - tuple[Tensor]: results of head containing the following tensor. - - - all_layers_outputs_classes (Tensor): Outputs from the - classification head, has shape (num_decoder_layers, bs, - num_queries, cls_out_channels). - - all_layers_outputs_coords (Tensor): Sigmoid outputs from the - regression head with normalized coordinate format (cx, cy, w, - h), has shape (num_decoder_layers, bs, num_queries, 4) with the - last dimension arranged as (cx, cy, w, h). - """ + """Forward function.""" all_layers_outputs_classes = [] all_layers_outputs_coords = [] @@ -159,31 +156,7 @@ def predict(self, batch_data_samples: SampleList, rescale: bool = True) -> InstanceList: """Perform forward propagation and loss calculation of the detection - head on the queries of the upstream network. - - Args: - hidden_states (Tensor): Hidden states output from each decoder - layer, has shape (num_decoder_layers, num_queries, bs, dim). - references (list[Tensor]): List of the reference from the decoder. - The first reference is the `init_reference` (initial) and the - other num_decoder_layers(6) references are `inter_references` - (intermediate). The `init_reference` has shape (bs, - num_queries, 4) when `as_two_stage` of the detector is `True`, - otherwise (bs, num_queries, 2). Each `inter_reference` has - shape (bs, num_queries, 4) when `with_box_refine` of the - detector is `True`, otherwise (bs, num_queries, 2). The - coordinates are arranged as (cx, cy) when the last dimension is - 2, and (cx, cy, w, h) when it is 4. - batch_data_samples (list[:obj:`DetDataSample`]): The Data - Samples. It usually includes information such as - `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`. - rescale (bool, optional): If `True`, return boxes in original - image space. Defaults to `True`. - - Returns: - list[obj:`InstanceData`]: Detection results of each image - after the post process. - """ + head on the queries of the upstream network.""" batch_img_metas = [ data_samples.metainfo for data_samples in batch_data_samples ] @@ -244,43 +217,7 @@ def loss(self, enc_outputs_class: Tensor, enc_outputs_coord: Tensor, hm_memory: Tensor, hm_mask: Tensor, dec_outputs_coord: Tensor, batch_data_samples: SampleList) -> dict: """Perform forward propagation and loss calculation of the detection - head on the queries of the upstream network. - - Args: - hidden_states (Tensor): Hidden states output from each decoder - layer, has shape (num_decoder_layers, num_queries, bs, dim). - references (list[Tensor]): List of the reference from the decoder. - The first reference is the `init_reference` (initial) and the - other num_decoder_layers(6) references are `inter_references` - (intermediate). The `init_reference` has shape (bs, - num_queries, 4) when `as_two_stage` of the detector is `True`, - otherwise (bs, num_queries, 2). Each `inter_reference` has - shape (bs, num_queries, 4) when `with_box_refine` of the - detector is `True`, otherwise (bs, num_queries, 2). The - coordinates are arranged as (cx, cy) when the last dimension is - 2, and (cx, cy, w, h) when it is 4. - enc_outputs_class (Tensor): The score of each point on encode - feature map, has shape (bs, num_feat_points, cls_out_channels). - Only when `as_two_stage` is `True` it would be passed in, - otherwise it would be `None`. - enc_outputs_coord (Tensor): The proposal generate from the encode - feature map, has shape (bs, num_feat_points, 4) with the last - dimension arranged as (cx, cy, w, h). Only when `as_two_stage` - is `True` it would be passed in, otherwise it would be `None`. - batch_data_samples (list[:obj:`DetDataSample`]): The Data - Samples. It usually includes information such as - `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`. - - Returns: - dict: A dictionary of loss components. - """ - # enc_outputs_class: [bs, mlv_shape, 1] - # enc_outputs_coord: [bs, mlv_shape, 2*num_keypoints] - # all_layers_classes [3, bs, num_queries, 1] - # all_layers_coords [3, bs, num_queries, 2*num_keypoints] - # hm_memory: [bs, lv0_h, lv0_w, 256] - # hm_mask: [bs, lv0_h, lv0_w] - # dec_outputs_coord: [2, max_inst, num_keypoints, 2] + head on the queries of the upstream network.""" batch_gt_instances = [] batch_img_metas = [] @@ -352,24 +289,8 @@ def loss_by_feat_single(self, batch_img_metas: List[dict], compute_oks_loss: bool = True) -> Tuple[Tensor]: """Loss function for outputs from a single decoder layer of a single - feature level. + feature level.""" - Args: - cls_scores (Tensor): Box score logits from a single decoder layer - for all images, has shape (bs, num_queries, cls_out_channels). - kpt_preds (Tensor): Sigmoid outputs from a single decoder layer - for all images, with normalized coordinate (cx, cy, w, h) and - shape (bs, num_queries, 4). - batch_gt_instances (list[:obj:`InstanceData`]): Batch of - gt_instance. It usually includes ``bboxes`` and ``labels`` - attributes. - batch_img_metas (list[dict]): Meta information of each image, e.g., - image size, scaling factor, etc. - - Returns: - Tuple[Tensor]: A tuple including `loss_cls`, `loss_box` and - `loss_iou`. - """ # cls_scores [bs, num_queries, 1] # kpt_preds [bs, num_queries, 2*num_keypoitns] @@ -446,24 +367,8 @@ def loss_by_feat_single(self, def loss_refined_kpts(self, kpt_preds: Tensor, batch_img_metas: List[dict]) -> Tuple[Tensor]: """Loss function for outputs from a single decoder layer of a single - feature level. + feature level.""" - Args: - cls_scores (Tensor): Box score logits from a single decoder layer - for all images, has shape (bs, num_queries, cls_out_channels). - kpt_preds (Tensor): Sigmoid outputs from a single decoder layer - for all images, with normalized coordinate (cx, cy, w, h) and - shape (bs, num_queries, 4). - batch_gt_instances (list[:obj:`InstanceData`]): Batch of - gt_instance. It usually includes ``bboxes`` and ``labels`` - attributes. - batch_img_metas (list[dict]): Meta information of each image, e.g., - image size, scaling factor, etc. - - Returns: - Tuple[Tensor]: A tuple including `loss_cls`, `loss_box` and - `loss_iou`. - """ # kpt_preds [num_selected, num_keypoints, 2] bbox_targets_list = self._target_buffer['bbox_targets_list'] kpt_targets_list = self._target_buffer['kpt_targets_list'] @@ -519,6 +424,7 @@ def loss_refined_kpts(self, kpt_preds: Tensor, return loss_kpt, loss_oks def loss_heatmap(self, hm_memory, hm_mask, batch_gt_fields): + """Heatmap loss function for outputs from the heatmap encoder.""" # compute heatmap predition pred_heatmaps = self.heatmap_fc(hm_memory) @@ -568,7 +474,8 @@ def get_targets(self, - labels_list (list[Tensor]): Labels for all images. - label_weights_list (list[Tensor]): Label weights for all images. - bbox_targets_list (list[Tensor]): BBox targets for all images. - - bbox_weights_list (list[Tensor]): BBox weights for all images. + - kpt_targets_list (list[Tensor]): Keypoint targets for all images. + - kpt_weights_list (list[Tensor]): Keypoint weights for all images. - num_total_pos (int): Number of positive samples in all images. - num_total_neg (int): Number of negative samples in all images. """ @@ -617,7 +524,8 @@ def _get_targets_single(self, cls_score: Tensor, kpt_pred: Tensor, - labels (Tensor): Labels of each image. - label_weights (Tensor]): Label weights of each image. - bbox_targets (Tensor): BBox targets of each image. - - bbox_weights (Tensor): BBox weights of each image. + - kpt_targets (Tensor): Keypoint targets of each image. + - kpt_weights (Tensor): Keypoint weights of each image. - pos_inds (Tensor): Sampled positive indices for each image. - neg_inds (Tensor): Sampled negative indices for each image. """ diff --git a/projects/petr/models/transformers.py b/projects/petr/models/transformers.py index 8f02570a1f..ff24316ff0 100644 --- a/projects/petr/models/transformers.py +++ b/projects/petr/models/transformers.py @@ -131,6 +131,7 @@ def forward(self, class PetrTransformerDecoderLayer(DeformableDetrTransformerDecoderLayer): + """Decoder layer of PETR.""" def _init_layers(self) -> None: """Initialize self_attn, cross-attn, ffn, and norms.""" @@ -147,6 +148,7 @@ def _init_layers(self) -> None: class PetrTransformerDecoder(DeformableDetrTransformerDecoder): + """Transformer Decoder of PETR.""" def __init__(self, num_keypoints: int, *args, **kwargs): self.num_keypoints = num_keypoints From 3d72a164b3f73f404f1568578eabaf414a6044ec Mon Sep 17 00:00:00 2001 From: lupeng Date: Wed, 31 May 2023 16:57:40 +0800 Subject: [PATCH 07/12] fix typo --- projects/README.md | 6 ++++++ projects/petr/README.md | 4 ++-- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/projects/README.md b/projects/README.md index 1089af8194..62126f9a85 100644 --- a/projects/README.md +++ b/projects/README.md @@ -61,3 +61,9 @@ We also provide some documentation listed below to help you get started:
- **What's next? Join the rank of *MMPose contributors* by creating a new project**! + +- **[:bulb:PETR](./petr)**: End-to-End Multi-Person Pose Estimation with Transformers + +
+ +

diff --git a/projects/petr/README.md b/projects/petr/README.md index fcece2a17f..b8dae3511b 100644 --- a/projects/petr/README.md +++ b/projects/petr/README.md @@ -1,4 +1,4 @@ -# YOLOX-Pose +# PETR This project implements PETR (Pose Estimation with TRansformers), an end-to-end multi-person pose estimation framework introduced in the CVPR 2022 paper **End-to-End Multi-Person Pose Estimation with Transformers**. PETR is a novel, end-to-end multi-person pose estimation method that treats pose estimation as a hierarchical set prediction problem. By leveraging attention mechanisms, PETR can adaptively focus on features most relevant to target keypoints, thereby overcoming feature misalignment issues in pose estimation. @@ -23,7 +23,7 @@ export PYTHONPATH=`pwd`:$PYTHONPATH ### Inference -Users can apply YOLOX-Pose models to estimate human poses using the inferencer found in the MMPose core package. Use the command below: +Users can apply PETR models to estimate human poses using the inferencer found in the MMPose core package. Use the command below: ```shell python demo/inferencer_demo.py $INPUTS \ From 5c29bc9c1d4abc9dc3ad76698d1149a2b2012ab1 Mon Sep 17 00:00:00 2001 From: lupeng Date: Wed, 31 May 2023 16:59:03 +0800 Subject: [PATCH 08/12] add soft links for demo and tools --- projects/petr/demo | 1 + projects/petr/tools | 1 + 2 files changed, 2 insertions(+) create mode 120000 projects/petr/demo create mode 120000 projects/petr/tools diff --git a/projects/petr/demo b/projects/petr/demo new file mode 120000 index 0000000000..bf71256cd3 --- /dev/null +++ b/projects/petr/demo @@ -0,0 +1 @@ +../../demo \ No newline at end of file diff --git a/projects/petr/tools b/projects/petr/tools new file mode 120000 index 0000000000..31941e941d --- /dev/null +++ b/projects/petr/tools @@ -0,0 +1 @@ +../../tools \ No newline at end of file From 6508e33da8fa97a8563413e1414a8758fb494d32 Mon Sep 17 00:00:00 2001 From: lupeng Date: Sat, 10 Jun 2023 01:47:30 +0800 Subject: [PATCH 09/12] modify some training settings --- .../petr/datasets/bbox_keypoint_structure.py | 16 +++++ projects/petr/datasets/coco_dataset.py | 61 ++++++++++++++++++- projects/petr/datasets/transforms.py | 50 ++++++++++++++- projects/petr/models/losses.py | 5 +- projects/petr/models/match_costs.py | 17 ++---- projects/petr/models/petr_head.py | 35 ++++++----- 6 files changed, 150 insertions(+), 34 deletions(-) diff --git a/projects/petr/datasets/bbox_keypoint_structure.py b/projects/petr/datasets/bbox_keypoint_structure.py index 6b385f2f09..d53df0e3f4 100644 --- a/projects/petr/datasets/bbox_keypoint_structure.py +++ b/projects/petr/datasets/bbox_keypoint_structure.py @@ -43,6 +43,7 @@ def __init__(self, data: Union[Tensor, np.ndarray], keypoints: Union[Tensor, np.ndarray], keypoints_visible: Union[Tensor, np.ndarray], + area: Union[Tensor, np.ndarray], dtype: Optional[torch.dtype] = None, device: Optional[DeviceType] = None, clone: bool = True, @@ -58,23 +59,28 @@ def __init__(self, assert len(data) == len(keypoints) assert len(data) == len(keypoints_visible) + assert len(data) == len(area) assert keypoints.ndim == 3 assert keypoints_visible.ndim == 2 keypoints = torch.as_tensor(keypoints) keypoints_visible = torch.as_tensor(keypoints_visible) + area = torch.as_tensor(area) if device is not None: keypoints = keypoints.to(device=device) keypoints_visible = keypoints_visible.to(device=device) + area = area.to(device=device) if clone: keypoints = keypoints.clone() keypoints_visible = keypoints_visible.clone() + area = area.clone() self.keypoints = keypoints self.keypoints_visible = keypoints_visible + self.area = area self.flip_indices = flip_indices def flip_(self, @@ -122,6 +128,7 @@ def rescale_(self, scale_factor: Tuple[float, float]) -> None: boxes = self.tensor assert len(scale_factor) == 2 + self.area = self.area * scale_factor[0] * scale_factor[1] self.tensor = boxes * boxes.new_tensor(scale_factor).repeat(2) scale_factor = self.keypoints.new_tensor(scale_factor).reshape(1, 1, 2) self.keypoints = self.keypoints * scale_factor @@ -152,6 +159,7 @@ def project_(self, homography_matrix: Union[Tensor, np.ndarray]) -> None: homography_matrix (Tensor or np.ndarray): A 3x3 tensor or ndarray representing the homography matrix for the transformation. """ + boxes = self.tensor if isinstance(homography_matrix, np.ndarray): homography_matrix = boxes.new_tensor(homography_matrix) @@ -214,11 +222,13 @@ def cat(cls: Type[T], box_list: Sequence[T], dim: int = 0) -> T: dim=dim) th_kpt_vis_list = torch.cat( [boxes.keypoints_visible for boxes in box_list], dim=dim) + th_area_list = torch.cat([boxes.area for boxes in box_list], dim=dim) flip_indices = box_list[0].flip_indices return cls( th_box_list, th_kpt_list, th_kpt_vis_list, + th_area_list, clone=False, flip_indices=flip_indices) @@ -239,14 +249,17 @@ def __getitem__(self: T, index: IndexType) -> T: boxes = boxes[index] keypoints = self.keypoints[index] keypoints_visible = self.keypoints_visible[index] + area = self.area[index] if boxes.dim() == 1: boxes = boxes.reshape(1, -1) keypoints = keypoints.reshape(1, -1, 2) keypoints_visible = keypoints_visible.reshape(1, -1) + area = area.reshape(1, ) return type(self)( boxes, keypoints, keypoints_visible, + area, flip_indices=self.flip_indices, clone=False) @@ -263,6 +276,7 @@ def __deepcopy__(self, memo): other.tensor = self.tensor.clone() other.keypoints = self.keypoints.clone() other.keypoints_visible = self.keypoints_visible.clone() + other.area = self.area.clone() other.flip_indices = deepcopy(self.flip_indices) return other @@ -272,6 +286,7 @@ def clone(self: T) -> T: self.tensor, self.keypoints, self.keypoints_visible, + self.area, flip_indices=self.flip_indices, clone=True) @@ -281,5 +296,6 @@ def to(self: T, *args, **kwargs) -> T: self.tensor.to(*args, **kwargs), self.keypoints.to(*args, **kwargs), self.keypoints_visible.to(*args, **kwargs), + self.area.to(*args, **kwargs), flip_indices=self.flip_indices, clone=False) diff --git a/projects/petr/datasets/coco_dataset.py b/projects/petr/datasets/coco_dataset.py index 165aa2f251..23da016f1c 100644 --- a/projects/petr/datasets/coco_dataset.py +++ b/projects/petr/datasets/coco_dataset.py @@ -1,5 +1,6 @@ import copy -from typing import Optional +from itertools import filterfalse, groupby +from typing import Dict, List, Optional import numpy as np from mmdet.registry import DATASETS @@ -61,6 +62,7 @@ def parse_data_info(self, raw_data_info: dict) -> Optional[dict]: # Addition: add image width and height 'width': img_w, 'height': img_h, + 'area': ann['area'], 'bbox': bbox, 'bbox_score': np.ones(1, dtype=np.float32), 'num_keypoints': num_keypoints, @@ -79,3 +81,60 @@ def parse_data_info(self, raw_data_info: dict) -> Optional[dict]: data_info['crowd_index'] = img['crowdIndex'] return data_info + + def _get_bottomup_data_infos(self, instance_list: List[Dict], + image_list: List[Dict]) -> List[Dict]: + """Organize the data list in bottom-up mode.""" + + # bottom-up data list + data_list_bu = [] + + used_img_ids = set() + + # group instances by img_id + for img_id, data_infos in groupby(instance_list, + lambda x: x['img_id']): + used_img_ids.add(img_id) + data_infos = list(data_infos) + + # image data + img_path = data_infos[0]['img_path'] + data_info_bu = { + 'img_id': img_id, + 'img_path': img_path, + 'width': data_infos[0]['width'], + 'height': data_infos[0]['height'], + } + + for key in data_infos[0].keys(): + if key not in data_info_bu: + seq = [d[key] for d in data_infos] + if isinstance(seq[0], np.ndarray): + seq = np.concatenate(seq, axis=0) + seq = np.array(seq) if key == 'area' else seq + data_info_bu[key] = seq + + # The segmentation annotation of invalid objects will be used + # to generate valid region mask in the pipeline. + invalid_segs = [] + for data_info_invalid in filterfalse(self._is_valid_instance, + data_infos): + if 'segmentation' in data_info_invalid: + invalid_segs.append(data_info_invalid['segmentation']) + data_info_bu['invalid_segs'] = invalid_segs + + data_list_bu.append(data_info_bu) + + # add images without instance for evaluation + if self.test_mode: + for img_info in image_list: + if img_info['img_id'] not in used_img_ids: + data_info_bu = { + 'img_id': img_info['img_id'], + 'img_path': img_info['img_path'], + 'id': list(), + 'raw_ann_info': None, + } + data_list_bu.append(data_info_bu) + + return data_list_bu diff --git a/projects/petr/datasets/transforms.py b/projects/petr/datasets/transforms.py index be4c748505..efd130391f 100644 --- a/projects/petr/datasets/transforms.py +++ b/projects/petr/datasets/transforms.py @@ -1,10 +1,13 @@ # Copyright (c) OpenMMLab. All rights reserved. +import random from typing import Union import numpy as np from mmcv.transforms import BaseTransform +from mmcv.transforms.utils import cache_randomness from mmdet.datasets.transforms import FilterAnnotations as FilterDetAnnotations from mmdet.datasets.transforms import PackDetInputs +from mmdet.datasets.transforms import RandomAffine as MMDET_RandomAffine from mmdet.registry import TRANSFORMS from mmdet.structures.bbox.box_type import autocast_box_type from mmengine.structures import PixelData @@ -32,12 +35,14 @@ def transform(self, results: dict) -> dict: (0, len(results['flip_indices']), 2), dtype=np.float32) results['keypoints_visible'] = np.empty( (0, len(results['flip_indices'])), dtype=np.int32) + results['area'] = np.empty((0, ), dtype=np.int32) results['category_id'] = [] results['gt_bboxes'] = BBoxKeypoints( data=results['bbox'], keypoints=results['keypoints'], keypoints_visible=results['keypoints_visible'], + area=results['area'], flip_indices=results['flip_indices'], ) @@ -54,7 +59,8 @@ class PackDetPoseInputs(PackDetInputs): 'gt_bboxes_labels': 'labels', 'gt_masks': 'masks', 'gt_keypoints': 'keypoints', - 'gt_keypoints_visible': 'keypoints_visible' + 'gt_keypoints_visible': 'keypoints_visible', + 'gt_area': 'area' } field_mapping_table = { 'gt_heatmaps': 'gt_heatmaps', @@ -71,6 +77,7 @@ def transform(self, results: dict) -> dict: results['gt_keypoints'] = results['gt_bboxes'].keypoints results['gt_keypoints_visible'] = results[ 'gt_bboxes'].keypoints_visible + results['gt_area'] = results['gt_bboxes'].area # pack fields gt_fields = None @@ -226,3 +233,44 @@ def transform(self, results: dict) -> Union[dict, None]: results[key] = results[key][keep] return results + + +@TRANSFORMS.register_module(force=True) +class RandomAffine(MMDET_RandomAffine): + + @cache_randomness + def _get_random_homography_matrix(self, height, width): + + # Center + center_matrix = np.eye(3, dtype=np.float32) + center_matrix[0, 2] = -width / 2 # x translation (pixels) + center_matrix[1, 2] = -height / 2 # y translation (pixels) + + # Rotation + rotation_degree = random.uniform(-self.max_rotate_degree, + self.max_rotate_degree) + rotation_matrix = self._get_rotation_matrix(rotation_degree) + + # Scaling + scaling_ratio = random.uniform(self.scaling_ratio_range[0], + self.scaling_ratio_range[1]) + scaling_matrix = self._get_scaling_matrix(scaling_ratio) + + # Shear + x_degree = random.uniform(-self.max_shear_degree, + self.max_shear_degree) + y_degree = random.uniform(-self.max_shear_degree, + self.max_shear_degree) + shear_matrix = self._get_shear_matrix(x_degree, y_degree) + + # Translation + trans_x = random.uniform(0.5 - self.max_translate_ratio, + 0.5 + self.max_translate_ratio) * width + trans_y = random.uniform(0.5 - self.max_translate_ratio, + 0.5 + self.max_translate_ratio) * height + translate_matrix = self._get_translation_matrix(trans_x, trans_y) + + warp_matrix = ( + translate_matrix @ shear_matrix @ rotation_matrix @ scaling_matrix + @ center_matrix) + return warp_matrix diff --git a/projects/petr/models/losses.py b/projects/petr/models/losses.py index 279e6f7c6f..1cf4d9a4a8 100644 --- a/projects/petr/models/losses.py +++ b/projects/petr/models/losses.py @@ -54,7 +54,7 @@ def compute_oks(self, output: Tensor, target: Tensor, target_weights: Tensor, - bboxes: Optional[Tensor] = None) -> Tensor: + area: Optional[Tensor] = None) -> Tensor: """Calculates the OKS metric. Args: @@ -79,8 +79,7 @@ def compute_oks(self, if sigmas.device != dist.device: sigmas = sigmas.to(dist.device) dist = dist / (sigmas * 2) - if bboxes is not None: - area = torch.prod(bboxes[..., 2:] - bboxes[..., :2], dim=-1) * 0.53 + if area is not None: dist = dist / area.pow(0.5).clip(min=1e-8).unsqueeze(-1) return (torch.exp(-dist.pow(2) / 2) * target_weights).sum( diff --git a/projects/petr/models/match_costs.py b/projects/petr/models/match_costs.py index 2e17bb4087..380f3f9657 100644 --- a/projects/petr/models/match_costs.py +++ b/projects/petr/models/match_costs.py @@ -91,20 +91,11 @@ def __call__(self, """ # Extract keypoints and bounding boxes - pred_keypoints = pred_instances.keypoints - gt_keypoints = gt_instances.keypoints - gt_bboxes = gt_instances.bboxes + pred_keypoints = pred_instances.keypoints.unsqueeze(1) + gt_keypoints = gt_instances.keypoints.unsqueeze(0) gt_keypoints_visible = gt_instances.keypoints_visible - - # Normalize keypoints and bounding boxes based on image shape - img_h, img_w = img_meta['img_shape'] - factor = gt_keypoints.new_tensor([img_w, img_h]).reshape(1, 1, 2) - gt_keypoints = (gt_keypoints / factor).unsqueeze(0) - pred_keypoints = (pred_keypoints / factor).unsqueeze(1) - gt_bboxes = (gt_bboxes.reshape(-1, 2, 2) / factor).reshape(1, -1, 4) - + gt_areas = gt_instances.area.reshape(1, gt_keypoints.size(1)) # Calculate OKS cost kpt_cost = self.compute_oks(pred_keypoints, gt_keypoints, - gt_keypoints_visible, gt_bboxes) - kpt_cost = -kpt_cost + gt_keypoints_visible, gt_areas) return kpt_cost * self.weight diff --git a/projects/petr/models/petr_head.py b/projects/petr/models/petr_head.py index ac4ecd3c5c..8aa4c2d950 100644 --- a/projects/petr/models/petr_head.py +++ b/projects/petr/models/petr_head.py @@ -304,14 +304,14 @@ def loss_by_feat_single(self, cls_reg_targets = self.get_targets(cls_scores_list, kpt_preds_list, batch_gt_instances, batch_img_metas) (labels_list, label_weights_list, bbox_targets_list, kpt_targets_list, - kpt_weights_list, num_total_pos, num_total_neg) = cls_reg_targets + kpt_weights_list, area_targets_list, num_total_pos, + num_total_neg) = cls_reg_targets labels = torch.cat(labels_list, 0) # [bs*300] label_weights = torch.cat(label_weights_list, 0) # [bs*300] (all 1) - bbox_targets = torch.cat(bbox_targets_list, - 0) # [bs*300, 4] (normalized) kpt_targets = torch.cat(kpt_targets_list, 0) # [bs*300, 17, 2] (normalized) kpt_weights = torch.cat(kpt_weights_list, 0) # [bs*300, 17] + area_targets = torch.cat(area_targets_list, 0) # [bs*300] # keypoint regression loss kpt_preds = kpt_preds.reshape(-1, self.num_keypoints, 2) @@ -354,11 +354,10 @@ def loss_by_feat_single(self, pos_kpt_preds = kpt_preds[pos_mask] * factors pos_kpt_targets = kpt_targets[pos_mask] * factors pos_kpt_weights = kpt_weights[pos_mask] - pos_bbox_targets = (bbox_targets[pos_mask].reshape(-1, 2, 2) * - factors).flatten(-2) + pos_area_targets = area_targets[pos_mask] loss_oks = self.loss_oks_aux(pos_kpt_preds, pos_kpt_targets, - pos_kpt_weights, pos_bbox_targets) + pos_kpt_weights, pos_area_targets) else: loss_oks = torch.zeros_like(loss_kpt) @@ -370,16 +369,15 @@ def loss_refined_kpts(self, kpt_preds: Tensor, feature level.""" # kpt_preds [num_selected, num_keypoints, 2] - bbox_targets_list = self._target_buffer['bbox_targets_list'] kpt_targets_list = self._target_buffer['kpt_targets_list'] kpt_weights_list = self._target_buffer['kpt_weights_list'] + area_targets_list = self._target_buffer['area_targets_list'] num_queries = len(kpt_targets_list[0]) - bbox_targets = torch.cat(bbox_targets_list, - 0).contiguous() # [bs*300, 4] (normalized) kpt_targets = torch.cat(kpt_targets_list, 0).contiguous() # [bs*300, 17, 2] (normalized) kpt_weights = torch.cat(kpt_weights_list, 0).contiguous() # [bs*300, 17] + area_targets = torch.cat(area_targets_list, 0).contiguous() # [bs*300] pos_mask = (kpt_weights.sum(-1) > 0).contiguous() pos_inds = (pos_mask.nonzero()).div( @@ -413,11 +411,12 @@ def loss_refined_kpts(self, kpt_preds: Tensor, pos_kpt_targets = pos_kpt_targets * factors pos_kpt_weights = kpt_weights[pos_mask] - pos_bbox_targets = (bbox_targets[pos_mask].reshape(-1, 2, 2) * - factors).flatten(-2) + # pos_bbox_targets = (bbox_targets[pos_mask].reshape(-1, 2, 2) * + # factors).flatten(-2) + pos_area_targets = area_targets[pos_mask] loss_oks = self.loss_oks_aux(pos_kpt_preds, pos_kpt_targets, - pos_kpt_weights, pos_bbox_targets) + pos_kpt_weights, pos_area_targets) else: loss_oks = torch.zeros_like(loss_kpt) @@ -480,7 +479,7 @@ def get_targets(self, - num_total_neg (int): Number of negative samples in all images. """ (labels_list, label_weights_list, bbox_targets_list, kpt_targets_list, - kpt_weights_list, pos_inds_list, + kpt_weights_list, area_targets_list, pos_inds_list, neg_inds_list) = multi_apply(self._get_targets_single, cls_scores_list, kpt_preds_list, batch_gt_instances, batch_img_metas) @@ -493,12 +492,13 @@ def get_targets(self, self._target_buffer['bbox_targets_list'] = bbox_targets_list self._target_buffer['kpt_targets_list'] = kpt_targets_list self._target_buffer['kpt_weights_list'] = kpt_weights_list + self._target_buffer['area_targets_list'] = area_targets_list self._target_buffer['num_total_pos'] = num_total_pos self._target_buffer['num_total_neg'] = num_total_neg return (labels_list, label_weights_list, bbox_targets_list, - kpt_targets_list, kpt_weights_list, num_total_pos, - num_total_neg) + kpt_targets_list, kpt_weights_list, area_targets_list, + num_total_pos, num_total_neg) def _get_targets_single(self, cls_score: Tensor, kpt_pred: Tensor, gt_instances: InstanceData, @@ -545,6 +545,7 @@ def _get_targets_single(self, cls_score: Tensor, kpt_pred: Tensor, gt_keypoints_visible = gt_instances.keypoints_visible gt_labels = gt_instances.labels gt_bboxes = gt_instances.bboxes + gt_areas = gt_instances.area pos_inds = torch.nonzero( assign_result.gt_inds > 0, as_tuple=False).squeeze(-1).unique() neg_inds = torch.nonzero( @@ -573,6 +574,8 @@ def _get_targets_single(self, cls_score: Tensor, kpt_pred: Tensor, bbox_targets[pos_inds] = pos_gt_bboxes.reshape( *pos_gt_bboxes.shape[:-1], 2, 2) / factor bbox_targets = bbox_targets.flatten(-2) + area_targets = kpt_pred.new_zeros(kpt_pred.shape[0]) + area_targets[pos_inds] = gt_areas[pos_assigned_gt_inds.long()].float() return (labels, label_weights, bbox_targets, kpt_targets, kpt_weights, - pos_inds, neg_inds) + area_targets, pos_inds, neg_inds) From 264715ffda0ec90cb6b33229be8c02373c7d6588 Mon Sep 17 00:00:00 2001 From: lupeng Date: Tue, 13 Jun 2023 21:14:48 +0800 Subject: [PATCH 10/12] update configs --- projects/petr/configs/petr_r101_8xb4-100e_coco.py | 5 ++++- projects/petr/configs/petr_r50_8xb3-100e_coco.py | 5 +++++ projects/petr/configs/petr_r50_8xb4-100e_coco.py | 11 +++++------ ...xb4-100e_coco.py => petr_swin-l_8xb2-100e_coco.py} | 7 +++++++ projects/petr/models/match_costs.py | 2 +- 5 files changed, 22 insertions(+), 8 deletions(-) create mode 100644 projects/petr/configs/petr_r50_8xb3-100e_coco.py rename projects/petr/configs/{petr_swin-l_8xb4-100e_coco.py => petr_swin-l_8xb2-100e_coco.py} (68%) diff --git a/projects/petr/configs/petr_r101_8xb4-100e_coco.py b/projects/petr/configs/petr_r101_8xb4-100e_coco.py index 9ac326f147..a30bbd3dd7 100644 --- a/projects/petr/configs/petr_r101_8xb4-100e_coco.py +++ b/projects/petr/configs/petr_r101_8xb4-100e_coco.py @@ -1,4 +1,7 @@ _base_ = ['petr_r50_8xb4-100e_coco.py'] # model -model = dict(backbone=dict(depth=101)) +checkpoint = 'https://download.openmmlab.com/mmpose/mmdet_pretrained/' \ + 'deformable_detr_twostage_refine_r101_16x2_50e_coco-3186d66b_20230613.pth' + +model = dict(init_cfg=dict(checkpoint=checkpoint), backbone=dict(depth=101)) diff --git a/projects/petr/configs/petr_r50_8xb3-100e_coco.py b/projects/petr/configs/petr_r50_8xb3-100e_coco.py new file mode 100644 index 0000000000..e3c842c60f --- /dev/null +++ b/projects/petr/configs/petr_r50_8xb3-100e_coco.py @@ -0,0 +1,5 @@ +_base_ = ['petr_r50_8xb4-100e_coco.py'] + +auto_scale_lr = dict(base_batch_size=24) +train_dataloader = dict(batch_size=3) +optim_wrapper = dict(optimizer=dict(lr=0.00015)) diff --git a/projects/petr/configs/petr_r50_8xb4-100e_coco.py b/projects/petr/configs/petr_r50_8xb4-100e_coco.py index 72ee7dbd78..225cf982d2 100644 --- a/projects/petr/configs/petr_r50_8xb4-100e_coco.py +++ b/projects/petr/configs/petr_r50_8xb4-100e_coco.py @@ -132,8 +132,8 @@ gamma=2.0, alpha=0.25, loss_weight=2.0), - loss_reg=dict(type='L1Loss', loss_weight=80.0), - loss_reg_aux=dict(type='L1Loss', loss_weight=70.0), + loss_reg=dict(type='L1Loss', loss_weight=40.0), + loss_reg_aux=dict(type='L1Loss', loss_weight=35.0), loss_oks=dict( type='OksLoss', metainfo='configs/_base_/datasets/coco.py', @@ -168,10 +168,9 @@ dict( type='RandomAffine', max_rotate_degree=30.0, - # max_translate_ratio=0., - # scaling_ratio_range=(1., 1.), - # max_shear_degree=0., - scaling_ratio_range=(0.75, 1.0), + max_translate_ratio=0., + scaling_ratio_range=(1., 1.), + max_shear_degree=0., border_val=[103.53, 116.28, 123.675], ), dict(type='RandomFlip', prob=0.5), diff --git a/projects/petr/configs/petr_swin-l_8xb4-100e_coco.py b/projects/petr/configs/petr_swin-l_8xb2-100e_coco.py similarity index 68% rename from projects/petr/configs/petr_swin-l_8xb4-100e_coco.py rename to projects/petr/configs/petr_swin-l_8xb2-100e_coco.py index 5bc8ad8f1e..c9abc36dad 100644 --- a/projects/petr/configs/petr_swin-l_8xb4-100e_coco.py +++ b/projects/petr/configs/petr_swin-l_8xb2-100e_coco.py @@ -1,7 +1,12 @@ _base_ = ['petr_r50_8xb4-100e_coco.py'] # model + +checkpoint = 'https://download.openmmlab.com/mmpose/mmdet_pretrained/' \ + 'deformable_detr_twostage_refine_swin_16x1_50e_coco-95953bd1_20230613.pth' + model = dict( + init_cfg=dict(checkpoint=checkpoint), backbone=dict( _delete_=True, type='SwinTransformer', @@ -21,4 +26,6 @@ convert_weights=True), neck=dict(in_channels=[384, 768, 1536])) +auto_scale_lr = dict(base_batch_size=16) +train_dataloader = dict(batch_size=2) optim_wrapper = dict(optimizer=dict(lr=0.0001)) diff --git a/projects/petr/models/match_costs.py b/projects/petr/models/match_costs.py index 380f3f9657..de239c3b17 100644 --- a/projects/petr/models/match_costs.py +++ b/projects/petr/models/match_costs.py @@ -98,4 +98,4 @@ def __call__(self, # Calculate OKS cost kpt_cost = self.compute_oks(pred_keypoints, gt_keypoints, gt_keypoints_visible, gt_areas) - return kpt_cost * self.weight + return -kpt_cost * self.weight From ae83687b02821ae7e9539edc4eada9d9e645b806 Mon Sep 17 00:00:00 2001 From: lupeng Date: Tue, 13 Jun 2023 21:29:26 +0800 Subject: [PATCH 11/12] update readme --- projects/petr/README.md | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/projects/petr/README.md b/projects/petr/README.md index b8dae3511b..451ba8fd15 100644 --- a/projects/petr/README.md +++ b/projects/petr/README.md @@ -37,13 +37,14 @@ For more information on using the inferencer, please see [this document](https:/ Results on COCO val2017 -| Model | Backbone | Lr schd | mAP | AP50 | AP75 | AR | AR50 | Config | Download | -| :---: | :------: | :-----: | :--: | :-------------: | :-------------: | :--: | :-------------: | :----------------------------------------------: | :-----------------------------------------------------------------------------: | -| PETR | R-50 | 100e | 68.7 | 87.5 | 76.2 | 75.9 | 92.1 | [config](/configs/petr_r50_8xb4-100e_coco.py) | [Google Drive](https://drive.google.com/file/d/1HcwraqWdZ3CaGMQOJHY8exNem7UnFkfS/view?usp=sharing) \| [BaiduYun](https://pan.baidu.com/s/1C0HbQWV7K-GHQE7q34nUZw?pwd=u798) | -| PETR | R-101 | 100e | 70.0 | 88.5 | 77.5 | 77.0 | 92.6 | [config](/configs/petr_r101_8xb4-100e_coco.py) | [Google Drive](https://drive.google.com/file/d/1O261Jrt4JRGlIKTmLtPy3AUruwX1hsDf/view?usp=sharing) \| [BaiduYun](https://pan.baidu.com/s/1D5wqNP53KNOKKE5NnO2Dnw?pwd=keyn) | -| PETR | Swin-L | 100e | 73.0 | 90.7 | 80.9 | 80.1 | 94.5 | [config](/configs/petr_swin-l_8xb4-100e_coco.py) | [Google Drive](https://drive.google.com/file/d/1ujL0Gm5tPjweT0-gdDGkTc7xXrEt6gBP/view?usp=sharing) \| [BaiduYun](https://pan.baidu.com/s/1X5Cdq75GosRCKqbHZTSpJQ?pwd=t9ea) | - -Currently, the PETR implemented in this project supports inference using the official checkpoint. However, the training accuracy is still not up to the results reported in the paper. We will continue to update this project after aligning the training accuracy. +| Model | mAP | AP50 | AP75 | AR | AR50 | Checkpoint | Log | +| :----------------------------------------------: | :--: | :-------------: | :-------------: | :--: | :-------------: | :---------------------------------------------------: | :--------------------------------------------: | +| [PETR-R50](/configs/petr_r50_8xb3-100e_coco.py) | 68.7 | 86.2 | 76.0 | 76.5 | 91.5 | [ckpt](https://download.openmmlab.com/mmpose/v1/projects/petr/petr_r50_8xb3-100e_coco-520803d9_20230613.pth) | [log](https://download.openmmlab.com/mmpose/v1/projects/petr/petr_r50_8xb3-100e_coco-20230613.json) | +| [PETR-R50](/configs/petr_r50_8xb4-100e_coco.py)\* | 68.7 | 87.5 | 76.2 | 75.9 | 92.1 | [Google Drive](https://drive.google.com/file/d/1HcwraqWdZ3CaGMQOJHY8exNem7UnFkfS/view?usp=sharing) \| [BaiduYun](https://pan.baidu.com/s/1C0HbQWV7K-GHQE7q34nUZw?pwd=u798) | / | +| [PETR-R101](/configs/petr_r101_8xb4-100e_coco.py)\* | 70.0 | 88.5 | 77.5 | 77.0 | 92.6 | [Google Drive](https://drive.google.com/file/d/1O261Jrt4JRGlIKTmLtPy3AUruwX1hsDf/view?usp=sharing) \| [BaiduYun](https://pan.baidu.com/s/1D5wqNP53KNOKKE5NnO2Dnw?pwd=keyn) | / | +| [PETR-Swin](/configs/petr_swin-l_8xb2-100e_coco.py)\* | 73.0 | 90.7 | 80.9 | 80.1 | 94.5 | [Google Drive](https://drive.google.com/file/d/1ujL0Gm5tPjweT0-gdDGkTc7xXrEt6gBP/view?usp=sharing) \| [BaiduYun](https://pan.baidu.com/s/1X5Cdq75GosRCKqbHZTSpJQ?pwd=t9ea) | / | + +\* Indicates that the checkpoints are from the official repository. The training accuracy is still not up to the results reported in the paper. We will continue to update this project after aligning the training accuracy. ## Citation From ce02f2c37e6f4c104e38d6597512dc3d2f397fb4 Mon Sep 17 00:00:00 2001 From: lupeng Date: Tue, 13 Jun 2023 21:35:37 +0800 Subject: [PATCH 12/12] fix grammar error --- projects/petr/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/projects/petr/README.md b/projects/petr/README.md index 451ba8fd15..6a38035aae 100644 --- a/projects/petr/README.md +++ b/projects/petr/README.md @@ -44,7 +44,7 @@ Results on COCO val2017 | [PETR-R101](/configs/petr_r101_8xb4-100e_coco.py)\* | 70.0 | 88.5 | 77.5 | 77.0 | 92.6 | [Google Drive](https://drive.google.com/file/d/1O261Jrt4JRGlIKTmLtPy3AUruwX1hsDf/view?usp=sharing) \| [BaiduYun](https://pan.baidu.com/s/1D5wqNP53KNOKKE5NnO2Dnw?pwd=keyn) | / | | [PETR-Swin](/configs/petr_swin-l_8xb2-100e_coco.py)\* | 73.0 | 90.7 | 80.9 | 80.1 | 94.5 | [Google Drive](https://drive.google.com/file/d/1ujL0Gm5tPjweT0-gdDGkTc7xXrEt6gBP/view?usp=sharing) \| [BaiduYun](https://pan.baidu.com/s/1X5Cdq75GosRCKqbHZTSpJQ?pwd=t9ea) | / | -\* Indicates that the checkpoints are from the official repository. The training accuracy is still not up to the results reported in the paper. We will continue to update this project after aligning the training accuracy. +\* Indicates that the checkpoints are sourced from the official repository. The training accuracy is still not up to the results reported in the paper. We will continue to update this project after aligning the training accuracy. ## Citation