diff --git a/projects/README.md b/projects/README.md
index 1089af8194..62126f9a85 100644
--- a/projects/README.md
+++ b/projects/README.md
@@ -61,3 +61,9 @@ We also provide some documentation listed below to help you get started:
- **What's next? Join the rank of *MMPose contributors* by creating a new project**!
+
+- **[:bulb:PETR](./petr)**: End-to-End Multi-Person Pose Estimation with Transformers
+
+
+
![](https://github.com/open-mmlab/mmpose/assets/26127467/ec7eb99d-8b8b-4c0d-9714-0ccd33a4f054)
+
diff --git a/projects/petr/README.md b/projects/petr/README.md
new file mode 100644
index 0000000000..6a38035aae
--- /dev/null
+++ b/projects/petr/README.md
@@ -0,0 +1,61 @@
+# PETR
+
+This project implements PETR (Pose Estimation with TRansformers), an end-to-end multi-person pose estimation framework introduced in the CVPR 2022 paper **End-to-End Multi-Person Pose Estimation with Transformers**. PETR is a novel, end-to-end multi-person pose estimation method that treats pose estimation as a hierarchical set prediction problem. By leveraging attention mechanisms, PETR can adaptively focus on features most relevant to target keypoints, thereby overcoming feature misalignment issues in pose estimation.
+
+![](https://github.com/open-mmlab/mmpose/assets/26127467/ec7eb99d-8b8b-4c0d-9714-0ccd33a4f054)
+
+## Usage
+
+### Prerequisites
+
+- Python 3.7 or higher
+- PyTorch 1.6 or higher
+- [MMEngine](https://github.com/open-mmlab/mmengine) v0.7.0 or higher
+- [MMCV](https://github.com/open-mmlab/mmcv) v2.0.0 or higher
+- [MMDetection](https://github.com/open-mmlab/mmdetection) v3.0.0 or higher
+- [MMPose](https://github.com/open-mmlab/mmpose) v1.0.0 or higher
+
+All the commands below rely on the correct configuration of `PYTHONPATH`, which should point to the project's directory so that Python can locate the module files. **In `petr/` root directory**, run the following line to add the current directory to `PYTHONPATH`:
+
+```shell
+export PYTHONPATH=`pwd`:$PYTHONPATH
+```
+
+### Inference
+
+Users can apply PETR models to estimate human poses using the inferencer found in the MMPose core package. Use the command below:
+
+```shell
+python demo/inferencer_demo.py $INPUTS \
+ --pose2d $CONFIG --pose2d-weights $CHECKPOINT --scope mmdet \
+ [--show] [--vis-out-dir $VIS_OUT_DIR] [--pred-out-dir $PRED_OUT_DIR]
+```
+
+For more information on using the inferencer, please see [this document](https://mmpose.readthedocs.io/en/latest/user_guides/inference.html#out-of-the-box-inferencer).
+
+### Results
+
+Results on COCO val2017
+
+| Model | mAP | AP50 | AP75 | AR | AR50 | Checkpoint | Log |
+| :----------------------------------------------: | :--: | :-------------: | :-------------: | :--: | :-------------: | :---------------------------------------------------: | :--------------------------------------------: |
+| [PETR-R50](/configs/petr_r50_8xb3-100e_coco.py) | 68.7 | 86.2 | 76.0 | 76.5 | 91.5 | [ckpt](https://download.openmmlab.com/mmpose/v1/projects/petr/petr_r50_8xb3-100e_coco-520803d9_20230613.pth) | [log](https://download.openmmlab.com/mmpose/v1/projects/petr/petr_r50_8xb3-100e_coco-20230613.json) |
+| [PETR-R50](/configs/petr_r50_8xb4-100e_coco.py)\* | 68.7 | 87.5 | 76.2 | 75.9 | 92.1 | [Google Drive](https://drive.google.com/file/d/1HcwraqWdZ3CaGMQOJHY8exNem7UnFkfS/view?usp=sharing) \| [BaiduYun](https://pan.baidu.com/s/1C0HbQWV7K-GHQE7q34nUZw?pwd=u798) | / |
+| [PETR-R101](/configs/petr_r101_8xb4-100e_coco.py)\* | 70.0 | 88.5 | 77.5 | 77.0 | 92.6 | [Google Drive](https://drive.google.com/file/d/1O261Jrt4JRGlIKTmLtPy3AUruwX1hsDf/view?usp=sharing) \| [BaiduYun](https://pan.baidu.com/s/1D5wqNP53KNOKKE5NnO2Dnw?pwd=keyn) | / |
+| [PETR-Swin](/configs/petr_swin-l_8xb2-100e_coco.py)\* | 73.0 | 90.7 | 80.9 | 80.1 | 94.5 | [Google Drive](https://drive.google.com/file/d/1ujL0Gm5tPjweT0-gdDGkTc7xXrEt6gBP/view?usp=sharing) \| [BaiduYun](https://pan.baidu.com/s/1X5Cdq75GosRCKqbHZTSpJQ?pwd=t9ea) | / |
+
+\* Indicates that the checkpoints are sourced from the official repository. The training accuracy is still not up to the results reported in the paper. We will continue to update this project after aligning the training accuracy.
+
+## Citation
+
+If this project benefits your work, please kindly consider citing the original papers:
+
+```bibtex
+@inproceedings{shi2022end,
+ title={End-to-end multi-person pose estimation with transformers},
+ author={Shi, Dahu and Wei, Xing and Li, Liangqi and Ren, Ye and Tan, Wenming},
+ booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
+ pages={11069--11078},
+ year={2022}
+}
+```
diff --git a/projects/petr/configs/_base_/datasets b/projects/petr/configs/_base_/datasets
new file mode 120000
index 0000000000..8feca66d56
--- /dev/null
+++ b/projects/petr/configs/_base_/datasets
@@ -0,0 +1 @@
+../../../../configs/_base_/datasets
\ No newline at end of file
diff --git a/projects/petr/configs/_base_/default_runtime.py b/projects/petr/configs/_base_/default_runtime.py
new file mode 100644
index 0000000000..fb9e38577c
--- /dev/null
+++ b/projects/petr/configs/_base_/default_runtime.py
@@ -0,0 +1,41 @@
+default_scope = 'mmdet'
+custom_imports = dict(imports=['models', 'datasets'])
+
+# hooks
+default_hooks = dict(
+ timer=dict(type='IterTimerHook'),
+ logger=dict(type='LoggerHook', interval=50),
+ param_scheduler=dict(type='ParamSchedulerHook'),
+ checkpoint=dict(type='CheckpointHook', interval=5, max_keep_ckpts=3),
+ sampler_seed=dict(type='DistSamplerSeedHook'),
+ visualization=dict(type='mmpose.PoseVisualizationHook', enable=False),
+)
+
+# multi-processing backend
+env_cfg = dict(
+ cudnn_benchmark=False,
+ mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
+ dist_cfg=dict(backend='nccl'),
+)
+
+# visualizer
+vis_backends = [dict(type='LocalVisBackend')]
+visualizer = dict(
+ type='mmpose.PoseLocalVisualizer',
+ vis_backends=vis_backends,
+ name='visualizer')
+
+# logger
+log_processor = dict(
+ type='LogProcessor', window_size=50, by_epoch=True, num_digits=6)
+log_level = 'INFO'
+load_from = None
+resume = False
+
+# file I/O backend
+backend_args = dict(backend='local')
+
+# training/validation/testing progress
+train_cfg = dict()
+val_cfg = dict(type='ValLoop')
+test_cfg = dict(type='TestLoop')
diff --git a/projects/petr/configs/petr_r101_8xb4-100e_coco.py b/projects/petr/configs/petr_r101_8xb4-100e_coco.py
new file mode 100644
index 0000000000..a30bbd3dd7
--- /dev/null
+++ b/projects/petr/configs/petr_r101_8xb4-100e_coco.py
@@ -0,0 +1,7 @@
+_base_ = ['petr_r50_8xb4-100e_coco.py']
+
+# model
+checkpoint = 'https://download.openmmlab.com/mmpose/mmdet_pretrained/' \
+ 'deformable_detr_twostage_refine_r101_16x2_50e_coco-3186d66b_20230613.pth'
+
+model = dict(init_cfg=dict(checkpoint=checkpoint), backbone=dict(depth=101))
diff --git a/projects/petr/configs/petr_r50_8xb3-100e_coco.py b/projects/petr/configs/petr_r50_8xb3-100e_coco.py
new file mode 100644
index 0000000000..e3c842c60f
--- /dev/null
+++ b/projects/petr/configs/petr_r50_8xb3-100e_coco.py
@@ -0,0 +1,5 @@
+_base_ = ['petr_r50_8xb4-100e_coco.py']
+
+auto_scale_lr = dict(base_batch_size=24)
+train_dataloader = dict(batch_size=3)
+optim_wrapper = dict(optimizer=dict(lr=0.00015))
diff --git a/projects/petr/configs/petr_r50_8xb4-100e_coco.py b/projects/petr/configs/petr_r50_8xb4-100e_coco.py
new file mode 100644
index 0000000000..225cf982d2
--- /dev/null
+++ b/projects/petr/configs/petr_r50_8xb4-100e_coco.py
@@ -0,0 +1,263 @@
+_base_ = ['./_base_/default_runtime.py']
+
+# learning policy
+max_epochs = 100
+train_cfg = dict(
+ type='EpochBasedTrainLoop', max_epochs=max_epochs, val_interval=5)
+val_cfg = dict(type='ValLoop')
+test_cfg = dict(type='TestLoop')
+
+param_scheduler = [
+ dict(
+ type='MultiStepLR',
+ begin=0,
+ end=max_epochs,
+ by_epoch=True,
+ milestones=[80],
+ gamma=0.1)
+]
+
+# NOTE: `auto_scale_lr` is for automatically scaling LR,
+auto_scale_lr = dict(base_batch_size=32)
+
+# optimizer
+optim_wrapper = dict(
+ type='OptimWrapper',
+ optimizer=dict(type='AdamW', lr=0.0002, weight_decay=0.0001),
+ clip_grad=dict(max_norm=0.1, norm_type=2),
+ paramwise_cfg=dict(
+ custom_keys={
+ 'backbone': dict(lr_mult=0.1),
+ 'sampling_offsets': dict(lr_mult=0.1),
+ 'reference_points': dict(lr_mult=0.1)
+ }))
+
+# model
+num_keypoints = 17
+checkpoint = 'https://download.openmmlab.com/mmdetection/v2.0/deformable_' \
+ 'detr/deformable_detr_twostage_refine_r50_16x2_50e_coco/deformable_' \
+ 'detr_twostage_refine_r50_16x2_50e_coco_20210419_220613-9d28ab72.pth'
+model = dict(
+ type='PETR',
+ num_queries=300,
+ num_feature_levels=4,
+ num_keypoints=num_keypoints,
+ with_box_refine=True,
+ as_two_stage=True,
+ init_cfg=dict(
+ type='Pretrained',
+ checkpoint=checkpoint,
+ ),
+ data_preprocessor=dict(
+ type='DetDataPreprocessor',
+ mean=[123.675, 116.28, 103.53],
+ std=[58.395, 57.12, 57.375],
+ bgr_to_rgb=True,
+ pad_size_divisor=1),
+ backbone=dict(
+ type='ResNet',
+ depth=50,
+ num_stages=4,
+ out_indices=(1, 2, 3),
+ frozen_stages=1,
+ norm_cfg=dict(type='BN', requires_grad=False),
+ norm_eval=True,
+ style='pytorch'),
+ neck=dict(
+ type='ChannelMapper',
+ in_channels=[512, 1024, 2048],
+ kernel_size=1,
+ out_channels=256,
+ act_cfg=None,
+ norm_cfg=dict(type='GN', num_groups=32),
+ num_outs=4),
+ encoder=dict( # DeformableDetrTransformerEncoder
+ num_layers=6,
+ layer_cfg=dict( # DeformableDetrTransformerEncoderLayer
+ self_attn_cfg=dict( # MultiScaleDeformableAttention
+ embed_dims=256,
+ batch_first=True),
+ ffn_cfg=dict(
+ embed_dims=256, feedforward_channels=1024, ffn_drop=0.1))),
+ decoder=dict( # PetrTransformerDecoder
+ num_layers=3,
+ num_keypoints=num_keypoints,
+ return_intermediate=True,
+ layer_cfg=dict( # PetrTransformerDecoderLayer
+ self_attn_cfg=dict( # MultiheadAttention
+ embed_dims=256,
+ num_heads=8,
+ dropout=0.1,
+ batch_first=True),
+ cross_attn_cfg=dict( # MultiScaleDeformablePoseAttention
+ embed_dims=256,
+ num_points=num_keypoints,
+ batch_first=True),
+ ffn_cfg=dict(
+ embed_dims=256, feedforward_channels=1024, ffn_drop=0.1)),
+ post_norm_cfg=None),
+ hm_encoder=dict( # DeformableDetrTransformerEncoder
+ num_layers=1,
+ layer_cfg=dict( # DeformableDetrTransformerEncoderLayer
+ self_attn_cfg=dict( # MultiScaleDeformableAttention
+ embed_dims=256,
+ num_levels=1,
+ batch_first=True),
+ ffn_cfg=dict(
+ embed_dims=256, feedforward_channels=1024, ffn_drop=0.1))),
+ kpt_decoder=dict( # DeformableDetrTransformerDecoder
+ num_layers=2,
+ return_intermediate=True,
+ layer_cfg=dict( # DetrTransformerDecoderLayer
+ self_attn_cfg=dict( # MultiheadAttention
+ embed_dims=256,
+ num_heads=8,
+ dropout=0.1,
+ batch_first=True),
+ cross_attn_cfg=dict( # MultiScaleDeformableAttention
+ embed_dims=256,
+ im2col_step=128),
+ ffn_cfg=dict(
+ embed_dims=256, feedforward_channels=1024, ffn_drop=0.1)),
+ post_norm_cfg=None),
+ positional_encoding=dict(num_feats=128, normalize=True, offset=-0.5),
+ bbox_head=dict(
+ type='PETRHead',
+ num_classes=1,
+ num_keypoints=num_keypoints,
+ sync_cls_avg_factor=True,
+ loss_cls=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=2.0),
+ loss_reg=dict(type='L1Loss', loss_weight=40.0),
+ loss_reg_aux=dict(type='L1Loss', loss_weight=35.0),
+ loss_oks=dict(
+ type='OksLoss',
+ metainfo='configs/_base_/datasets/coco.py',
+ loss_weight=3.0),
+ loss_oks_aux=dict(
+ type='OksLoss',
+ metainfo='configs/_base_/datasets/coco.py',
+ loss_weight=2.0),
+ loss_hm=dict(type='mmpose.FocalHeatmapLoss', loss_weight=4.0),
+ ),
+ # training and testing settings
+ train_cfg=dict(
+ assigner=dict(
+ type='HungarianAssigner',
+ match_costs=[
+ dict(type='FocalLossCost', weight=2.0),
+ dict(type='KptL1Cost', weight=70.0),
+ dict(
+ type='OksCost',
+ metainfo='configs/_base_/datasets/coco.py',
+ weight=7.0)
+ ])),
+ test_cfg=dict(
+ max_per_img=100,
+ score_thr=0.0,
+ ))
+
+train_pipeline = [
+ dict(type='mmpose.LoadImage', backend_args=_base_.backend_args),
+ dict(type='PoseToDetConverter'),
+ dict(type='PhotoMetricDistortion'),
+ dict(
+ type='RandomAffine',
+ max_rotate_degree=30.0,
+ max_translate_ratio=0.,
+ scaling_ratio_range=(1., 1.),
+ max_shear_degree=0.,
+ border_val=[103.53, 116.28, 123.675],
+ ),
+ dict(type='RandomFlip', prob=0.5),
+ dict(
+ type='RandomChoice',
+ transforms=[
+ [
+ dict(
+ type='RandomChoiceResize',
+ scales=list(zip(range(400, 1401, 8), (1400, ) * 126)),
+ keep_ratio=True)
+ ],
+ [
+ dict(
+ type='RandomChoiceResize',
+ # The radio of all image in train dataset < 7
+ # follow the original implement
+ scales=[(400, 4200), (500, 4200), (600, 4200)],
+ keep_ratio=True),
+ dict(
+ type='RandomCrop',
+ crop_type='absolute_range',
+ crop_size=(384, 600),
+ allow_negative_crop=True),
+ dict(
+ type='RandomChoiceResize',
+ scales=list(zip(range(400, 1401, 8), (1400, ) * 126)),
+ keep_ratio=True)
+ ]
+ ]),
+ dict(type='FilterDetPoseAnnotations', keep_empty=False),
+ dict(type='GenerateHeatmap'),
+ dict(
+ type='PackDetPoseInputs',
+ meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape'))
+]
+
+test_pipeline = [
+ dict(type='mmpose.LoadImage', backend_args=_base_.backend_args),
+ dict(type='PoseToDetConverter'),
+ dict(type='Resize', scale=(1333, 800), keep_ratio=True),
+ dict(
+ type='PackDetPoseInputs',
+ meta_keys=('id', 'img_id', 'img_path', 'ori_shape', 'img_shape',
+ 'scale_factor', 'flip_indices'))
+]
+
+dataset_type = 'CocoDataset'
+data_mode = 'bottomup'
+data_root = 'data/coco/'
+
+train_dataloader = dict(
+ batch_size=4,
+ num_workers=4,
+ persistent_workers=True,
+ sampler=dict(type='DefaultSampler', shuffle=True),
+ batch_sampler=dict(type='AspectRatioBatchSampler'),
+ dataset=dict(
+ type=dataset_type,
+ data_mode=data_mode,
+ data_root=data_root,
+ ann_file='annotations/person_keypoints_train2017.json',
+ data_prefix=dict(img='train2017/'),
+ filter_cfg=dict(filter_empty_gt=True, min_size=32),
+ pipeline=train_pipeline))
+
+val_dataloader = dict(
+ batch_size=1,
+ num_workers=2,
+ persistent_workers=True,
+ drop_last=False,
+ sampler=dict(type='DefaultSampler', shuffle=False),
+ dataset=dict(
+ type=dataset_type,
+ data_mode=data_mode,
+ data_root=data_root,
+ ann_file='annotations/person_keypoints_val2017.json',
+ data_prefix=dict(img='val2017/'),
+ test_mode=True,
+ pipeline=test_pipeline))
+test_dataloader = val_dataloader
+
+val_evaluator = dict(
+ type='mmpose.CocoMetric',
+ ann_file=data_root + 'annotations/person_keypoints_val2017.json',
+ nms_mode='none',
+ score_mode='bbox')
+test_evaluator = val_evaluator
+
+default_hooks = dict(checkpoint=dict(save_best='coco/AP', rule='greater'))
diff --git a/projects/petr/configs/petr_swin-l_8xb2-100e_coco.py b/projects/petr/configs/petr_swin-l_8xb2-100e_coco.py
new file mode 100644
index 0000000000..c9abc36dad
--- /dev/null
+++ b/projects/petr/configs/petr_swin-l_8xb2-100e_coco.py
@@ -0,0 +1,31 @@
+_base_ = ['petr_r50_8xb4-100e_coco.py']
+
+# model
+
+checkpoint = 'https://download.openmmlab.com/mmpose/mmdet_pretrained/' \
+ 'deformable_detr_twostage_refine_swin_16x1_50e_coco-95953bd1_20230613.pth'
+
+model = dict(
+ init_cfg=dict(checkpoint=checkpoint),
+ backbone=dict(
+ _delete_=True,
+ type='SwinTransformer',
+ embed_dims=192,
+ depths=[2, 2, 18, 2],
+ num_heads=[6, 12, 24, 48],
+ window_size=7,
+ mlp_ratio=4,
+ qkv_bias=True,
+ qk_scale=None,
+ drop_rate=0.,
+ attn_drop_rate=0.,
+ drop_path_rate=0.5,
+ patch_norm=True,
+ out_indices=(1, 2, 3),
+ with_cp=False,
+ convert_weights=True),
+ neck=dict(in_channels=[384, 768, 1536]))
+
+auto_scale_lr = dict(base_batch_size=16)
+train_dataloader = dict(batch_size=2)
+optim_wrapper = dict(optimizer=dict(lr=0.0001))
diff --git a/projects/petr/datasets/__init__.py b/projects/petr/datasets/__init__.py
new file mode 100644
index 0000000000..69bae9de53
--- /dev/null
+++ b/projects/petr/datasets/__init__.py
@@ -0,0 +1,3 @@
+from .bbox_keypoint_structure import * # noqa
+from .coco_dataset import * # noqa
+from .transforms import * # noqa
diff --git a/projects/petr/datasets/bbox_keypoint_structure.py b/projects/petr/datasets/bbox_keypoint_structure.py
new file mode 100644
index 0000000000..d53df0e3f4
--- /dev/null
+++ b/projects/petr/datasets/bbox_keypoint_structure.py
@@ -0,0 +1,301 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from copy import deepcopy
+from typing import List, Optional, Sequence, Tuple, Type, TypeVar, Union
+
+import numpy as np
+import torch
+from mmdet.structures.bbox import HorizontalBoxes
+from torch import Tensor
+
+DeviceType = Union[str, torch.device]
+T = TypeVar('T')
+IndexType = Union[slice, int, list, torch.LongTensor, torch.cuda.LongTensor,
+ torch.BoolTensor, torch.cuda.BoolTensor, np.ndarray]
+
+
+class BBoxKeypoints(HorizontalBoxes):
+ """The BBoxKeypoints class is a combination of bounding boxes and keypoints
+ representation. The box format used in BBoxKeypoints is the same as
+ HorizontalBoxes.
+
+ Args:
+ data (Tensor or np.ndarray): The box data with shape of
+ (N, 4).
+ keypoints (Tensor or np.ndarray): The keypoint data with shape of
+ (N, K, 2).
+ keypoints_visible (Tensor or np.ndarray): The visibility of keypoints
+ with shape of (N, K).
+ dtype (torch.dtype, Optional): data type of boxes. Defaults to None.
+ device (str or torch.device, Optional): device of boxes.
+ Default to None.
+ clone (bool): Whether clone ``boxes`` or not. Defaults to True.
+ mode (str, Optional): the mode of boxes. If it is 'cxcywh', the
+ `data` will be converted to 'xyxy' mode. Defaults to None.
+ flip_indices (list, Optional): The indices of keypoints when the
+ images is flipped. Defaults to None.
+
+ Notes:
+ N: the number of instances.
+ K: the number of keypoints.
+ """
+
+ def __init__(self,
+ data: Union[Tensor, np.ndarray],
+ keypoints: Union[Tensor, np.ndarray],
+ keypoints_visible: Union[Tensor, np.ndarray],
+ area: Union[Tensor, np.ndarray],
+ dtype: Optional[torch.dtype] = None,
+ device: Optional[DeviceType] = None,
+ clone: bool = True,
+ in_mode: Optional[str] = None,
+ flip_indices: Optional[List] = None) -> None:
+
+ super().__init__(
+ data=data,
+ dtype=dtype,
+ device=device,
+ clone=clone,
+ in_mode=in_mode)
+
+ assert len(data) == len(keypoints)
+ assert len(data) == len(keypoints_visible)
+ assert len(data) == len(area)
+
+ assert keypoints.ndim == 3
+ assert keypoints_visible.ndim == 2
+
+ keypoints = torch.as_tensor(keypoints)
+ keypoints_visible = torch.as_tensor(keypoints_visible)
+ area = torch.as_tensor(area)
+
+ if device is not None:
+ keypoints = keypoints.to(device=device)
+ keypoints_visible = keypoints_visible.to(device=device)
+ area = area.to(device=device)
+
+ if clone:
+ keypoints = keypoints.clone()
+ keypoints_visible = keypoints_visible.clone()
+ area = area.clone()
+
+ self.keypoints = keypoints
+ self.keypoints_visible = keypoints_visible
+ self.area = area
+ self.flip_indices = flip_indices
+
+ def flip_(self,
+ img_shape: Tuple[int, int],
+ direction: str = 'horizontal') -> None:
+ """Flip boxes & kpts horizontally in-place.
+
+ Args:
+ img_shape (Tuple[int, int]): A tuple of image height and width.
+ direction (str): Flip direction, options are "horizontal",
+ "vertical" and "diagonal". Defaults to "horizontal"
+ """
+ assert direction == 'horizontal'
+ super().flip_(img_shape, direction)
+ self.keypoints[..., 0] = img_shape[1] - self.keypoints[..., 0]
+ self.keypoints = self.keypoints[:, self.flip_indices]
+ self.keypoints_visible = self.keypoints_visible[:, self.flip_indices]
+
+ def translate_(self, distances: Tuple[float, float]) -> None:
+ """Translate boxes and keypoints in-place.
+
+ Args:
+ distances (Tuple[float, float]): translate distances. The first
+ is horizontal distance and the second is vertical distance.
+ """
+ boxes = self.tensor
+ assert len(distances) == 2
+ self.tensor = boxes + boxes.new_tensor(distances).repeat(2)
+ distances = self.keypoints.new_tensor(distances).reshape(1, 1, 2)
+ self.keypoints = self.keypoints + distances
+
+ def rescale_(self, scale_factor: Tuple[float, float]) -> None:
+ """Rescale boxes & keypoints w.r.t. rescale_factor in-place.
+
+ Note:
+ Both ``rescale_`` and ``resize_`` will enlarge or shrink boxes
+ w.r.t ``scale_facotr``. The difference is that ``resize_`` only
+ changes the width and the height of boxes, but ``rescale_`` also
+ rescales the box centers simultaneously.
+
+ Args:
+ scale_factor (Tuple[float, float]): factors for scaling boxes.
+ The length should be 2.
+ """
+ boxes = self.tensor
+ assert len(scale_factor) == 2
+
+ self.area = self.area * scale_factor[0] * scale_factor[1]
+ self.tensor = boxes * boxes.new_tensor(scale_factor).repeat(2)
+ scale_factor = self.keypoints.new_tensor(scale_factor).reshape(1, 1, 2)
+ self.keypoints = self.keypoints * scale_factor
+
+ def clip_(self, img_shape: Tuple[int, int]) -> None:
+ """Clip bounding boxes and set invisible keypoints outside the image
+ boundary in-place.
+
+ Args:
+ img_shape (Tuple[int, int]): A tuple of image height and width.
+ """
+ boxes = self.tensor
+ boxes[..., 0::2] = boxes[..., 0::2].clamp(0, img_shape[1])
+ boxes[..., 1::2] = boxes[..., 1::2].clamp(0, img_shape[0])
+
+ kpt_outside = torch.logical_or(
+ torch.logical_or(self.keypoints[..., 0] < 0,
+ self.keypoints[..., 1] < 0),
+ torch.logical_or(self.keypoints[..., 0] > img_shape[1],
+ self.keypoints[..., 1] > img_shape[0]))
+ self.keypoints_visible[kpt_outside] *= 0
+
+ def project_(self, homography_matrix: Union[Tensor, np.ndarray]) -> None:
+ """Geometrically transform bounding boxes and keypoints in-place using
+ a homography matrix.
+
+ Args:
+ homography_matrix (Tensor or np.ndarray): A 3x3 tensor or ndarray
+ representing the homography matrix for the transformation.
+ """
+
+ boxes = self.tensor
+ if isinstance(homography_matrix, np.ndarray):
+ homography_matrix = boxes.new_tensor(homography_matrix)
+
+ # Convert boxes to corners in homogeneous coordinates
+ corners = self.hbox2corner(boxes)
+ corners = torch.cat(
+ [corners, corners.new_ones(*corners.shape[:-1], 1)], dim=-1)
+
+ # Convert keypoints to homogeneous coordinates
+ keypoints = torch.cat([
+ self.keypoints,
+ self.keypoints.new_ones(*self.keypoints.shape[:-1], 1)
+ ],
+ dim=-1)
+
+ # Transpose corners and keypoints for matrix multiplication
+ corners_T = torch.transpose(corners, -1, -2)
+ keypoints_T = torch.transpose(keypoints, -1, 0).contiguous().flatten(1)
+
+ # Apply homography matrix to corners and keypoints
+ corners_T = torch.matmul(homography_matrix, corners_T)
+ keypoints_T = torch.matmul(homography_matrix, keypoints_T)
+
+ # Transpose back to original shape
+ corners = torch.transpose(corners_T, -1, -2)
+ keypoints_T = keypoints_T.reshape(3, self.keypoints.shape[1], -1)
+ keypoints = torch.transpose(keypoints_T, -1, 0).contiguous()
+
+ # Convert corners and keypoints back to non-homogeneous coordinates
+ corners = corners[..., :2] / corners[..., 2:3]
+ keypoints = keypoints[..., :2] / keypoints[..., 2:3]
+
+ # Convert corners back to bounding boxes and update object attributes
+ self.tensor = self.corner2hbox(corners)
+ self.keypoints = keypoints
+
+ @classmethod
+ def cat(cls: Type[T], box_list: Sequence[T], dim: int = 0) -> T:
+ """Cancatenates an instance list into one single instance. Similar to
+ ``torch.cat``.
+
+ Args:
+ box_list (Sequence[T]): A sequence of instances.
+ dim (int): The dimension over which the box and keypoint are
+ concatenated. Defaults to 0.
+
+ Returns:
+ T: Concatenated instance.
+ """
+ assert isinstance(box_list, Sequence)
+ if len(box_list) == 0:
+ raise ValueError('box_list should not be a empty list.')
+
+ assert dim == 0
+ assert all(isinstance(boxes, cls) for boxes in box_list)
+
+ th_box_list = torch.cat([boxes.tensor for boxes in box_list], dim=dim)
+ th_kpt_list = torch.cat([boxes.keypoints for boxes in box_list],
+ dim=dim)
+ th_kpt_vis_list = torch.cat(
+ [boxes.keypoints_visible for boxes in box_list], dim=dim)
+ th_area_list = torch.cat([boxes.area for boxes in box_list], dim=dim)
+ flip_indices = box_list[0].flip_indices
+ return cls(
+ th_box_list,
+ th_kpt_list,
+ th_kpt_vis_list,
+ th_area_list,
+ clone=False,
+ flip_indices=flip_indices)
+
+ def __getitem__(self: T, index: IndexType) -> T:
+ """Rewrite getitem to protect the last dimension shape."""
+ boxes = self.tensor
+ if isinstance(index, np.ndarray):
+ index = torch.as_tensor(index, device=self.device)
+ if isinstance(index, Tensor) and index.dtype == torch.bool:
+ assert index.dim() < boxes.dim()
+ elif isinstance(index, tuple):
+ assert len(index) < boxes.dim()
+ # `Ellipsis`(...) is commonly used in index like [None, ...].
+ # When `Ellipsis` is in index, it must be the last item.
+ if Ellipsis in index:
+ assert index[-1] is Ellipsis
+
+ boxes = boxes[index]
+ keypoints = self.keypoints[index]
+ keypoints_visible = self.keypoints_visible[index]
+ area = self.area[index]
+ if boxes.dim() == 1:
+ boxes = boxes.reshape(1, -1)
+ keypoints = keypoints.reshape(1, -1, 2)
+ keypoints_visible = keypoints_visible.reshape(1, -1)
+ area = area.reshape(1, )
+ return type(self)(
+ boxes,
+ keypoints,
+ keypoints_visible,
+ area,
+ flip_indices=self.flip_indices,
+ clone=False)
+
+ @property
+ def num_keypoints(self) -> Tensor:
+ """Compute the number of visible keypoints for each object."""
+ return self.keypoints_visible.sum(dim=1).int()
+
+ def __deepcopy__(self, memo):
+ """Only clone the tensors when applying deepcopy."""
+ cls = self.__class__
+ other = cls.__new__(cls)
+ memo[id(self)] = other
+ other.tensor = self.tensor.clone()
+ other.keypoints = self.keypoints.clone()
+ other.keypoints_visible = self.keypoints_visible.clone()
+ other.area = self.area.clone()
+ other.flip_indices = deepcopy(self.flip_indices)
+ return other
+
+ def clone(self: T) -> T:
+ """Reload ``clone`` for tensors."""
+ return type(self)(
+ self.tensor,
+ self.keypoints,
+ self.keypoints_visible,
+ self.area,
+ flip_indices=self.flip_indices,
+ clone=True)
+
+ def to(self: T, *args, **kwargs) -> T:
+ """Reload ``to`` for tensors."""
+ return type(self)(
+ self.tensor.to(*args, **kwargs),
+ self.keypoints.to(*args, **kwargs),
+ self.keypoints_visible.to(*args, **kwargs),
+ self.area.to(*args, **kwargs),
+ flip_indices=self.flip_indices,
+ clone=False)
diff --git a/projects/petr/datasets/coco_dataset.py b/projects/petr/datasets/coco_dataset.py
new file mode 100644
index 0000000000..23da016f1c
--- /dev/null
+++ b/projects/petr/datasets/coco_dataset.py
@@ -0,0 +1,140 @@
+import copy
+from itertools import filterfalse, groupby
+from typing import Dict, List, Optional
+
+import numpy as np
+from mmdet.registry import DATASETS
+
+from mmpose.datasets import CocoDataset as MMPoseCocoDataset
+
+
+@DATASETS.register_module(force=True)
+class CocoDataset(MMPoseCocoDataset):
+
+ def parse_data_info(self, raw_data_info: dict) -> Optional[dict]:
+ """Parse raw COCO annotation of an instance. Compared with original
+ implementation, this method add image width and height in `data_info`
+
+ Args:
+ raw_data_info (dict): Raw data information loaded from
+ ``ann_file``. It should have following contents:
+
+ - ``'raw_ann_info'``: Raw annotation of an instance
+ - ``'raw_img_info'``: Raw information of the image that
+ contains the instance
+
+ Returns:
+ dict | None: Parsed instance annotation
+ """
+
+ ann = raw_data_info['raw_ann_info']
+ img = raw_data_info['raw_img_info']
+
+ # filter invalid instance
+ if 'bbox' not in ann or 'keypoints' not in ann:
+ return None
+
+ img_w, img_h = img['width'], img['height']
+
+ # get bbox in shape [1, 4], formatted as xywh
+ x, y, w, h = ann['bbox']
+ x1 = np.clip(x, 0, img_w - 1)
+ y1 = np.clip(y, 0, img_h - 1)
+ x2 = np.clip(x + w, 0, img_w - 1)
+ y2 = np.clip(y + h, 0, img_h - 1)
+
+ bbox = np.array([x1, y1, x2, y2], dtype=np.float32).reshape(1, 4)
+
+ # keypoints in shape [1, K, 2] and keypoints_visible in [1, K]
+ _keypoints = np.array(
+ ann['keypoints'], dtype=np.float32).reshape(1, -1, 3)
+ keypoints = _keypoints[..., :2]
+ keypoints_visible = np.minimum(1, _keypoints[..., 2])
+
+ if 'num_keypoints' in ann:
+ num_keypoints = ann['num_keypoints']
+ else:
+ num_keypoints = np.count_nonzero(keypoints.max(axis=2))
+
+ data_info = {
+ 'img_id': ann['image_id'],
+ 'img_path': img['img_path'],
+ # Addition: add image width and height
+ 'width': img_w,
+ 'height': img_h,
+ 'area': ann['area'],
+ 'bbox': bbox,
+ 'bbox_score': np.ones(1, dtype=np.float32),
+ 'num_keypoints': num_keypoints,
+ 'keypoints': keypoints,
+ 'keypoints_visible': keypoints_visible,
+ 'iscrowd': ann.get('iscrowd', 0),
+ 'segmentation': ann.get('segmentation', None),
+ 'id': ann['id'],
+ 'category_id': ann['category_id'],
+ # store the raw annotation of the instance
+ # it is useful for evaluation without providing ann_file
+ 'raw_ann_info': copy.deepcopy(ann),
+ }
+
+ if 'crowdIndex' in img:
+ data_info['crowd_index'] = img['crowdIndex']
+
+ return data_info
+
+ def _get_bottomup_data_infos(self, instance_list: List[Dict],
+ image_list: List[Dict]) -> List[Dict]:
+ """Organize the data list in bottom-up mode."""
+
+ # bottom-up data list
+ data_list_bu = []
+
+ used_img_ids = set()
+
+ # group instances by img_id
+ for img_id, data_infos in groupby(instance_list,
+ lambda x: x['img_id']):
+ used_img_ids.add(img_id)
+ data_infos = list(data_infos)
+
+ # image data
+ img_path = data_infos[0]['img_path']
+ data_info_bu = {
+ 'img_id': img_id,
+ 'img_path': img_path,
+ 'width': data_infos[0]['width'],
+ 'height': data_infos[0]['height'],
+ }
+
+ for key in data_infos[0].keys():
+ if key not in data_info_bu:
+ seq = [d[key] for d in data_infos]
+ if isinstance(seq[0], np.ndarray):
+ seq = np.concatenate(seq, axis=0)
+ seq = np.array(seq) if key == 'area' else seq
+ data_info_bu[key] = seq
+
+ # The segmentation annotation of invalid objects will be used
+ # to generate valid region mask in the pipeline.
+ invalid_segs = []
+ for data_info_invalid in filterfalse(self._is_valid_instance,
+ data_infos):
+ if 'segmentation' in data_info_invalid:
+ invalid_segs.append(data_info_invalid['segmentation'])
+ data_info_bu['invalid_segs'] = invalid_segs
+
+ data_list_bu.append(data_info_bu)
+
+ # add images without instance for evaluation
+ if self.test_mode:
+ for img_info in image_list:
+ if img_info['img_id'] not in used_img_ids:
+ data_info_bu = {
+ 'img_id': img_info['img_id'],
+ 'img_path': img_info['img_path'],
+ 'id': list(),
+ 'raw_ann_info': None,
+ }
+ data_list_bu.append(data_info_bu)
+
+ return data_list_bu
diff --git a/projects/petr/datasets/transforms.py b/projects/petr/datasets/transforms.py
new file mode 100644
index 0000000000..efd130391f
--- /dev/null
+++ b/projects/petr/datasets/transforms.py
@@ -0,0 +1,276 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import random
+from typing import Union
+
+import numpy as np
+from mmcv.transforms import BaseTransform
+from mmcv.transforms.utils import cache_randomness
+from mmdet.datasets.transforms import FilterAnnotations as FilterDetAnnotations
+from mmdet.datasets.transforms import PackDetInputs
+from mmdet.datasets.transforms import RandomAffine as MMDET_RandomAffine
+from mmdet.registry import TRANSFORMS
+from mmdet.structures.bbox.box_type import autocast_box_type
+from mmengine.structures import PixelData
+
+from mmpose.codecs.utils import generate_gaussian_heatmaps
+from .bbox_keypoint_structure import BBoxKeypoints
+
+
+@TRANSFORMS.register_module(force=True)
+class PoseToDetConverter(BaseTransform):
+ """This transform converts the pose data element into a format that is
+ suitable for the mmdet transforms."""
+
+ def transform(self, results: dict) -> dict:
+
+ results['seg_map_path'] = None
+ results['height'] = results['img_shape'][0]
+ results['width'] = results['img_shape'][1]
+
+ num_instances = len(results.get('bbox', []))
+
+ if num_instances == 0:
+ results['bbox'] = np.empty((0, 4), dtype=np.float32)
+ results['keypoints'] = np.empty(
+ (0, len(results['flip_indices']), 2), dtype=np.float32)
+ results['keypoints_visible'] = np.empty(
+ (0, len(results['flip_indices'])), dtype=np.int32)
+ results['area'] = np.empty((0, ), dtype=np.int32)
+ results['category_id'] = []
+
+ results['gt_bboxes'] = BBoxKeypoints(
+ data=results['bbox'],
+ keypoints=results['keypoints'],
+ keypoints_visible=results['keypoints_visible'],
+ area=results['area'],
+ flip_indices=results['flip_indices'],
+ )
+
+ results['gt_ignore_flags'] = np.array([False] * num_instances)
+ results['gt_bboxes_labels'] = np.array(results['category_id']) - 1
+
+ return results
+
+
+@TRANSFORMS.register_module(force=True)
+class PackDetPoseInputs(PackDetInputs):
+ mapping_table = {
+ 'gt_bboxes': 'bboxes',
+ 'gt_bboxes_labels': 'labels',
+ 'gt_masks': 'masks',
+ 'gt_keypoints': 'keypoints',
+ 'gt_keypoints_visible': 'keypoints_visible',
+ 'gt_area': 'area'
+ }
+ field_mapping_table = {
+ 'gt_heatmaps': 'gt_heatmaps',
+ }
+
+ def __init__(self,
+ meta_keys=('id', 'img_id', 'img_path', 'ori_shape',
+ 'img_shape', 'scale_factor', 'flip',
+ 'flip_direction', 'flip_indices', 'raw_ann_info'),
+ pack_transformed=False):
+ self.meta_keys = meta_keys
+
+ def transform(self, results: dict) -> dict:
+ results['gt_keypoints'] = results['gt_bboxes'].keypoints
+ results['gt_keypoints_visible'] = results[
+ 'gt_bboxes'].keypoints_visible
+ results['gt_area'] = results['gt_bboxes'].area
+
+ # pack fields
+ gt_fields = None
+ for key, packed_key in self.field_mapping_table.items():
+ if key in results:
+
+ if gt_fields is None:
+ gt_fields = PixelData()
+ else:
+ assert isinstance(
+ gt_fields, PixelData
+ ), 'Got mixed single-level and multi-level pixel data.'
+
+ gt_fields.set_field(results[key], packed_key)
+
+ results = super().transform(results)
+ if gt_fields:
+ results['data_samples'].gt_fields = gt_fields.to_tensor()
+
+ return results
+
+
+@TRANSFORMS.register_module()
+class GenerateHeatmap(BaseTransform):
+ """This class is responsible for creating a heatmap based on bounding boxes
+ and keypoints.
+
+ It generates a Gaussian heatmap for each instance in the image, given the
+ bounding box and keypoints.
+ """
+
+ def _get_instance_wise_sigmas(self,
+ bbox: np.ndarray,
+ heatmap_min_overlap: float = 0.9
+ ) -> np.ndarray:
+ """Compute sigma values for each instance based on their bounding box
+ size.
+
+ Args:
+ bbox (np.ndarray): Bounding box in shape (N, 4, 2)
+ heatmap_min_overlap (float, optional): Minimum overlap for
+ the heatmap. Defaults to 0.9.
+
+ Returns:
+ np.ndarray: Array containing the sigma values for each instance.
+ """
+ sigmas = np.zeros((bbox.shape[0], ), dtype=np.float32)
+
+ heights = bbox[:, 3] - bbox[:, 1]
+ widths = bbox[:, 2] - bbox[:, 0]
+
+ for i in range(bbox.shape[0]):
+ h, w = heights[i], widths[i]
+
+ # compute sigma for each instance
+ # condition 1
+ a1, b1 = 1, h + w
+ c1 = w * h * (1 - heatmap_min_overlap) / (1 + heatmap_min_overlap)
+ sq1 = np.sqrt(b1**2 - 4 * a1 * c1)
+ r1 = (b1 + sq1) / 2
+
+ # condition 2
+ a2 = 4
+ b2 = 2 * (h + w)
+ c2 = (1 - heatmap_min_overlap) * w * h
+ sq2 = np.sqrt(b2**2 - 4 * a2 * c2)
+ r2 = (b2 + sq2) / 2
+
+ # condition 3
+ a3 = 4 * heatmap_min_overlap
+ b3 = -2 * heatmap_min_overlap * (h + w)
+ c3 = (heatmap_min_overlap - 1) * w * h
+ sq3 = np.sqrt(b3**2 - 4 * a3 * c3)
+ r3 = (b3 + sq3) / 2
+
+ sigmas[i] = min(r1, r2, r3) / 3
+
+ return sigmas
+
+ @autocast_box_type()
+ def transform(self, results: dict) -> Union[dict, None]:
+ """Apply the transformation to filter annotations.
+
+ This function rescales bounding boxes and keypoints, and generates a
+ Gaussian heatmap for each instance.
+ """
+
+ assert 'gt_bboxes' in results
+
+ bbox = results['gt_bboxes'].tensor.numpy() / 8
+ keypoints = results['gt_bboxes'].keypoints.numpy() / 8
+ keypoints_visible = results['gt_bboxes'].keypoints_visible.numpy()
+
+ heatmap_size = [
+ results['img_shape'][1] // 8 + 1, results['img_shape'][0] // 8 + 1
+ ]
+ sigmas = self._get_instance_wise_sigmas(bbox)
+
+ hm, _ = generate_gaussian_heatmaps(heatmap_size, keypoints,
+ keypoints_visible, sigmas)
+
+ results['gt_heatmaps'] = hm
+
+ return results
+
+
+@TRANSFORMS.register_module()
+class FilterDetPoseAnnotations(FilterDetAnnotations):
+ """Filter invalid annotations.
+
+ In addition to the conditions checked by ``FilterDetAnnotations``, this
+ filter adds a new condition requiring instances to have at least one
+ visible keypoints.
+ """
+
+ @autocast_box_type()
+ def transform(self, results: dict) -> Union[dict, None]:
+ """Transform function to filter annotations.
+
+ Args:
+ results (dict): Result dict.
+
+ Returns:
+ dict: Updated result dict.
+ """
+ assert 'gt_bboxes' in results
+ gt_bboxes = results['gt_bboxes']
+ if gt_bboxes.shape[0] == 0:
+ return results
+
+ tests = []
+ if self.by_box:
+ tests.append(((gt_bboxes.widths > self.min_gt_bbox_wh[0]) &
+ (gt_bboxes.heights > self.min_gt_bbox_wh[1]) &
+ (gt_bboxes.num_keypoints > 0)).numpy())
+
+ if self.by_mask:
+ assert 'gt_masks' in results
+ gt_masks = results['gt_masks']
+ tests.append(gt_masks.areas >= self.min_gt_mask_area)
+
+ keep = tests[0]
+ for t in tests[1:]:
+ keep = keep & t
+
+ if not keep.any():
+ if self.keep_empty:
+ return None
+
+ keys = ('gt_bboxes', 'gt_bboxes_labels', 'gt_masks', 'gt_ignore_flags')
+ for key in keys:
+ if key in results:
+ results[key] = results[key][keep]
+
+ return results
+
+
+@TRANSFORMS.register_module(force=True)
+class RandomAffine(MMDET_RandomAffine):
+
+ @cache_randomness
+ def _get_random_homography_matrix(self, height, width):
+
+ # Center
+ center_matrix = np.eye(3, dtype=np.float32)
+ center_matrix[0, 2] = -width / 2 # x translation (pixels)
+ center_matrix[1, 2] = -height / 2 # y translation (pixels)
+
+ # Rotation
+ rotation_degree = random.uniform(-self.max_rotate_degree,
+ self.max_rotate_degree)
+ rotation_matrix = self._get_rotation_matrix(rotation_degree)
+
+ # Scaling
+ scaling_ratio = random.uniform(self.scaling_ratio_range[0],
+ self.scaling_ratio_range[1])
+ scaling_matrix = self._get_scaling_matrix(scaling_ratio)
+
+ # Shear
+ x_degree = random.uniform(-self.max_shear_degree,
+ self.max_shear_degree)
+ y_degree = random.uniform(-self.max_shear_degree,
+ self.max_shear_degree)
+ shear_matrix = self._get_shear_matrix(x_degree, y_degree)
+
+ # Translation
+ trans_x = random.uniform(0.5 - self.max_translate_ratio,
+ 0.5 + self.max_translate_ratio) * width
+ trans_y = random.uniform(0.5 - self.max_translate_ratio,
+ 0.5 + self.max_translate_ratio) * height
+ translate_matrix = self._get_translation_matrix(trans_x, trans_y)
+
+ warp_matrix = (
+ translate_matrix @ shear_matrix @ rotation_matrix @ scaling_matrix
+ @ center_matrix)
+ return warp_matrix
diff --git a/projects/petr/demo b/projects/petr/demo
new file mode 120000
index 0000000000..bf71256cd3
--- /dev/null
+++ b/projects/petr/demo
@@ -0,0 +1 @@
+../../demo
\ No newline at end of file
diff --git a/projects/petr/models/__init__.py b/projects/petr/models/__init__.py
new file mode 100644
index 0000000000..e9b7e5f0c1
--- /dev/null
+++ b/projects/petr/models/__init__.py
@@ -0,0 +1,4 @@
+from .losses import * # noqa
+from .match_costs import * # noqa
+from .petr import * # noqa
+from .petr_head import * # noqa
diff --git a/projects/petr/models/losses.py b/projects/petr/models/losses.py
new file mode 100644
index 0000000000..1cf4d9a4a8
--- /dev/null
+++ b/projects/petr/models/losses.py
@@ -0,0 +1,86 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Optional
+
+import torch
+import torch.nn as nn
+from mmdet.registry import MODELS
+from torch import Tensor
+
+from mmpose.datasets.datasets.utils import parse_pose_metainfo
+
+
+@MODELS.register_module(force=True)
+class OksLoss(nn.Module):
+ """A PyTorch implementation of the Object Keypoint Similarity (OKS) loss as
+ described in the paper "End-to-End Multi-Person Pose Estimation with
+ Transformers" by Shi et al. (2022).
+
+ The OKS loss is used for keypoint-based object recognition and consists
+ of a measure of the similarity between predicted and ground truth
+ keypoint locations, adjusted by the size of the object in the image.
+
+ The loss function takes as input the predicted keypoint locations, the
+ ground truth keypoint locations, a mask indicating which keypoints are
+ valid, and bounding boxes for the objects.
+
+ Args:
+ metainfo (Optional[str]): Path to a JSON file containing information
+ about the dataset's annotations.
+ loss_weight (float): Weight for the loss.
+ """
+
+ def __init__(self,
+ metainfo: Optional[str] = None,
+ loss_weight: float = 1.0):
+ super().__init__()
+
+ if metainfo is not None:
+ metainfo = parse_pose_metainfo(dict(from_file=metainfo))
+ sigmas = metainfo.get('sigmas', None)
+ if sigmas is not None:
+ self.register_buffer('sigmas', torch.as_tensor(sigmas))
+ self.loss_weight = loss_weight
+
+ def forward(self,
+ output: Tensor,
+ target: Tensor,
+ target_weights: Tensor,
+ bboxes: Optional[Tensor] = None) -> Tensor:
+ oks = self.compute_oks(output, target, target_weights, bboxes)
+ loss = -torch.log(oks.clamp(min=1e-6))
+ return loss.mean() * self.loss_weight
+
+ def compute_oks(self,
+ output: Tensor,
+ target: Tensor,
+ target_weights: Tensor,
+ area: Optional[Tensor] = None) -> Tensor:
+ """Calculates the OKS metric.
+
+ Args:
+ output (Tensor): Predicted keypoints in shape N x k x 2, where N
+ is batch size, k is the number of keypoints, and 2 are the
+ xy coordinates.
+ target (Tensor): Ground truth keypoints in the same shape as
+ output.
+ target_weights (Tensor): Mask of valid keypoints in shape N x k,
+ with 1 for valid and 0 for invalid.
+ bboxes (Optional[Tensor]): Bounding boxes in shape N x 4,
+ where 4 are the xyxy coordinates.
+
+ Returns:
+ Tensor: The calculated OKS.
+ """
+
+ dist = torch.norm(output - target, dim=-1)
+
+ if hasattr(self, 'sigmas'):
+ sigmas = self.sigmas.reshape(*((1, ) * (dist.ndim - 1)), -1)
+ if sigmas.device != dist.device:
+ sigmas = sigmas.to(dist.device)
+ dist = dist / (sigmas * 2)
+ if area is not None:
+ dist = dist / area.pow(0.5).clip(min=1e-8).unsqueeze(-1)
+
+ return (torch.exp(-dist.pow(2) / 2) * target_weights).sum(
+ dim=-1) / target_weights.sum(dim=-1).clip(min=1e-8)
diff --git a/projects/petr/models/match_costs.py b/projects/petr/models/match_costs.py
new file mode 100644
index 0000000000..de239c3b17
--- /dev/null
+++ b/projects/petr/models/match_costs.py
@@ -0,0 +1,101 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Optional
+
+from mmdet.models.task_modules.assigners.match_cost import BaseMatchCost
+from mmdet.registry import TASK_UTILS
+from mmengine.structures import InstanceData
+from torch import Tensor
+
+from .losses import OksLoss
+
+
+@TASK_UTILS.register_module()
+class KptL1Cost(BaseMatchCost):
+ """This class computes the L1 cost between predicted and ground truth
+ keypoints.
+
+ The cost is computed based on the normalized difference between the
+ keypoints. The keypoints visibility is also taken into account while
+ calculating the cost.
+ """
+
+ def __call__(self,
+ pred_instances: InstanceData,
+ gt_instances: InstanceData,
+ img_meta: Optional[dict] = None,
+ **kwargs) -> Tensor:
+ """Compute the L1 cost between predicted and ground truth keypoints.
+
+ Args:
+ pred_instances (InstanceData): Predicted instances data.
+ gt_instances (InstanceData): Ground truth instances data.
+ img_meta (dict, optional): Meta data of the image. Defaults
+ to None.
+
+ Returns:
+ Tensor: L1 cost between predicted and ground truth keypoints.
+ """
+
+ # Extract keypoints from predicted and ground truth instances
+ pred_keypoints = pred_instances.keypoints
+ gt_keypoints = gt_instances.keypoints
+
+ # Get the visibility of keypoints and normalize it
+ gt_keypoints_visible = gt_instances.keypoints_visible
+ gt_keypoints_visible = gt_keypoints_visible / (
+ 2 * gt_keypoints_visible.sum(dim=1, keepdim=True) + 1e-8)
+
+ # Normalize keypoints based on image shape
+ img_h, img_w = img_meta['img_shape']
+ factor = gt_keypoints.new_tensor([img_w, img_h]).reshape(1, 1, 2)
+ gt_keypoints = (gt_keypoints / factor).unsqueeze(0)
+ gt_keypoints_visible = gt_keypoints_visible.unsqueeze(0).unsqueeze(-1)
+ pred_keypoints = (pred_keypoints / factor).unsqueeze(1)
+
+ # Calculate L1 cost considering visibility of keypoints
+ diff = (pred_keypoints - gt_keypoints) * gt_keypoints_visible
+ kpt_cost = diff.flatten(2).norm(dim=2, p=1)
+
+ return kpt_cost * self.weight
+
+
+@TASK_UTILS.register_module()
+class OksCost(BaseMatchCost, OksLoss):
+ """This class computes the OKS (Object Keypoint Similarity) cost between
+ predicted and ground truth keypoints.
+
+ It normalizes keypoints based on image shape, then calculates the OKS using
+ a method from the OksLoss class. It also includes visibility and bounding
+ box information in the calculation.
+ """
+
+ def __init__(self, metainfo: Optional[str] = None, weight: float = 1.0):
+ OksLoss.__init__(self, metainfo, weight)
+ self.weight = self.loss_weight
+
+ def __call__(self,
+ pred_instances: InstanceData,
+ gt_instances: InstanceData,
+ img_meta: Optional[dict] = None,
+ **kwargs) -> Tensor:
+ """Compute the OKS cost between predicted and ground truth keypoints.
+
+ Args:
+ pred_instances (InstanceData): Predicted instances data.
+ gt_instances (InstanceData): Ground truth instances data.
+ img_meta (dict, optional): Meta data of the image. Defaults
+ to None.
+
+ Returns:
+ Tensor: OKS cost between predicted and ground truth keypoints.
+ """
+
+ # Extract keypoints and bounding boxes
+ pred_keypoints = pred_instances.keypoints.unsqueeze(1)
+ gt_keypoints = gt_instances.keypoints.unsqueeze(0)
+ gt_keypoints_visible = gt_instances.keypoints_visible
+ gt_areas = gt_instances.area.reshape(1, gt_keypoints.size(1))
+ # Calculate OKS cost
+ kpt_cost = self.compute_oks(pred_keypoints, gt_keypoints,
+ gt_keypoints_visible, gt_areas)
+ return -kpt_cost * self.weight
diff --git a/projects/petr/models/petr.py b/projects/petr/models/petr.py
new file mode 100644
index 0000000000..1a9673dfb2
--- /dev/null
+++ b/projects/petr/models/petr.py
@@ -0,0 +1,691 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from collections import OrderedDict
+from typing import Dict, Tuple, Union
+
+import torch
+import torch.nn.functional as F
+from mmcv.cnn.bricks.transformer import MultiScaleDeformableAttention
+from mmdet.models.detectors import DeformableDETR
+from mmdet.models.layers import (DeformableDetrTransformerDecoder,
+ DeformableDetrTransformerEncoder,
+ SinePositionalEncoding)
+from mmdet.models.layers.transformer.utils import inverse_sigmoid
+from mmdet.registry import MODELS
+from mmdet.structures import OptSampleList, SampleList
+from mmengine.model import xavier_init
+from torch import Tensor, nn
+from torch.nn.init import normal_
+
+from .transformers import (MultiScaleDeformablePoseAttention,
+ PetrTransformerDecoder)
+
+
+@MODELS.register_module()
+class PETR(DeformableDETR):
+ r"""Implementation of `End-to-End Multi-Person Pose Estimation with
+ Transformers `_
+
+ Code is modified from the `official github repo
+ `_.
+
+ Args:
+ num_keypoints (int): Numbder of Keypoints. Defaults to 17.
+ hm_encoder (:obj:`ConfigDict` or dict, optional): Config of the
+ heatmap encoder. Defaults to None.
+ kpt_decoder (:obj:`ConfigDict` or dict, optional): Config for the
+ keypoint refine decoder. Defaults to None.
+ """
+ _version = 2
+
+ def __init__(self,
+ num_keypoints: int = 17,
+ hm_encoder: dict = None,
+ kpt_decoder: dict = None,
+ *args,
+ **kwargs):
+ self.num_keypoints = num_keypoints
+ self.hm_encoder = hm_encoder
+ self.kpt_decoder = kpt_decoder
+ super().__init__(*args, **kwargs)
+
+ self._register_load_state_dict_pre_hook(self._load_state_dict_pre_hook)
+ self.encoder._register_load_state_dict_pre_hook(
+ self._load_state_dict_pre_hook)
+ self.decoder._register_load_state_dict_pre_hook(
+ self._load_state_dict_pre_hook)
+
+ def _init_layers(self) -> None:
+ """Initialize layers except for backbone, neck and bbox_head."""
+ self.positional_encoding = SinePositionalEncoding(
+ **self.positional_encoding)
+ self.encoder = DeformableDetrTransformerEncoder(**self.encoder)
+ self.decoder = PetrTransformerDecoder(**self.decoder)
+ self.hm_encoder = DeformableDetrTransformerEncoder(**self.hm_encoder)
+ self.kpt_decoder = DeformableDetrTransformerDecoder(**self.kpt_decoder)
+ self.embed_dims = self.encoder.embed_dims
+ self.query_embedding = nn.Embedding(self.num_queries,
+ self.embed_dims * 2)
+ self.kpt_query_embedding = nn.Embedding(self.num_keypoints,
+ self.embed_dims * 2)
+
+ num_feats = self.positional_encoding.num_feats
+ assert num_feats * 2 == self.embed_dims, \
+ 'embed_dims should be exactly 2 times of num_feats. ' \
+ f'Found {self.embed_dims} and {num_feats}.'
+
+ self.level_embed = nn.Parameter(
+ torch.Tensor(self.num_feature_levels, self.embed_dims))
+
+ if self.as_two_stage:
+ self.memory_trans_fc = nn.Linear(self.embed_dims, self.embed_dims)
+ self.memory_trans_norm = nn.LayerNorm(self.embed_dims)
+ else:
+ self.reference_points_fc = nn.Linear(self.embed_dims, 2)
+
+ def init_weights(self) -> None:
+ """Initialize weights for Transformer and other components."""
+ super(DeformableDETR, self).init_weights()
+ for coder in self.encoder, self.decoder:
+ for p in coder.parameters():
+ if p.dim() > 1:
+ nn.init.xavier_uniform_(p)
+ for m in self.modules():
+ if isinstance(m, MultiScaleDeformableAttention):
+ m.init_weights()
+ for m in self.modules():
+ if isinstance(m, MultiScaleDeformablePoseAttention):
+ m.init_weights()
+ if self.as_two_stage:
+ nn.init.xavier_uniform_(self.memory_trans_fc.weight)
+ else:
+ xavier_init(
+ self.reference_points_fc, distribution='uniform', bias=0.)
+ normal_(self.level_embed)
+ normal_(self.kpt_query_embedding.weight)
+
+ def forward_transformer(self,
+ img_feats: Tuple[Tensor],
+ batch_data_samples: OptSampleList = None,
+ test_mode: bool = True) -> Dict:
+ """Forward process of Transformer in PETR.
+
+ This function consists of seven stages: 'pre_transformer',
+ 'forward_encoder', 'pre_decoder', 'forward_decoder',
+ 'pre_kpt_decoder', 'forward_kpt_decoder', and 'forward_kpt_head'.
+ It takes image features (img_feats) and batch data samples
+ (batch_data_samples) as inputs and performs transformations at
+ each stage. The output is a dictionary of inputs to be used for the
+ bounding box head function (bbox_head).
+
+ .. code:: text
+
+ img_feats & batch_data_samples
+ |
+ V
+ +-----------------+
+ | pre_transformer |
+ +-----------------+
+ | |
+ | V
+ | +-----------------+
+ | | forward_encoder |
+ | +-----------------+
+ | |
+ | V
+ | +---------------+
+ | | pre_decoder |
+ | +---------------+
+ | | |
+ V V |
+ +-----------------+ |
+ | forward_decoder | |
+ +-----------------+ |
+ | |
+ V V
+ +-----------------+ |
+ | pre_kpt_decoder | |
+ +-----------------+ |
+ | |
+ V V
+ +--------------------+ |
+ | forward_kpt_decoder| |
+ +--------------------+ |
+ | |
+ V V
+ +----------------+ |
+ |forward_kpt_head| |
+ +----------------+ |
+ | |
+ V V
+ head_inputs_dict
+
+
+ Args:
+ img_feats (tuple[Tensor]): Tuple of feature maps from neck. Each
+ feature map has shape (bs, dim, H, W).
+ batch_data_samples (list[:obj:`DetDataSample`], optional): The
+ batch data samples. It usually includes information such
+ as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`.
+ Defaults to None.
+ test_mode (bool, optional): If True, the function operates in test
+ mode. Defaults to True.
+
+ Returns:
+ head_inputs_dict (dict): The dictionary of bbox_head function
+ inputs. Always includes 'hidden_states' from the decoder output
+ and may contain 'references' including the initial and
+ intermediate references. The specific contents of this dict
+ differ based on whether the function is operating in test_mode
+ or not. In test_mode, 'det_labels' and 'det_scores' are
+ included. In training mode, it includes additional elements
+ such as 'enc_outputs_class', 'enc_outputs_coord',
+ 'all_layers_classes', 'all_layers_coords', 'hm_memory',
+ and 'hm_mask'.
+ """
+ encoder_inputs_dict, decoder_inputs_dict = self.pre_transformer(
+ img_feats, batch_data_samples)
+
+ encoder_outputs_dict, heatmap_dict = self.forward_encoder(
+ **encoder_inputs_dict, test_mode=test_mode)
+
+ tmp_dec_in, head_inputs_dict = self.pre_decoder(**encoder_outputs_dict)
+
+ decoder_inputs_dict.update(tmp_dec_in)
+
+ decoder_outputs_dict = self.forward_decoder(**decoder_inputs_dict)
+
+ kpt_decoder_inputs_dict = self.pre_kpt_decoder(
+ **decoder_outputs_dict,
+ batch_data_samples=batch_data_samples,
+ test_mode=test_mode)
+
+ kpt_decoder_inputs_dict.update(decoder_inputs_dict)
+
+ kpt_decoder_outputs_dict = self.forward_kpt_decoder(
+ **kpt_decoder_inputs_dict)
+
+ dec_outputs_coord = self.forward_kpt_head(**kpt_decoder_outputs_dict)
+
+ head_inputs_dict['dec_outputs_coord'] = dec_outputs_coord
+ if test_mode:
+ head_inputs_dict['det_labels'] = kpt_decoder_inputs_dict[
+ 'det_labels']
+ head_inputs_dict['det_scores'] = kpt_decoder_inputs_dict[
+ 'det_scores']
+ else:
+ head_inputs_dict.update(heatmap_dict)
+ head_inputs_dict['all_layers_classes'] = decoder_outputs_dict[
+ 'all_layers_classes']
+ head_inputs_dict['all_layers_coords'] = decoder_outputs_dict[
+ 'all_layers_coords']
+
+ return head_inputs_dict
+
+ def loss(self, batch_inputs: Tensor,
+ batch_data_samples: SampleList) -> Union[dict, list]:
+ """Calculate losses from a batch of inputs and data samples.
+
+ Args:
+ batch_inputs (Tensor): Input images of shape (bs, dim, H, W).
+ These should usually be mean centered and std scaled.
+ batch_data_samples (List[:obj:`DetDataSample`]): The batch
+ data samples. It usually includes information such
+ as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`.
+
+ Returns:
+ dict: A dictionary of loss components
+ """
+
+ img_feats = self.extract_feat(batch_inputs)
+ head_inputs_dict = self.forward_transformer(
+ img_feats, batch_data_samples, test_mode=False)
+
+ losses = self.bbox_head.loss(
+ **head_inputs_dict, batch_data_samples=batch_data_samples)
+
+ return losses
+
+ def predict(self,
+ batch_inputs: Tensor,
+ batch_data_samples: SampleList,
+ rescale: bool = True) -> SampleList:
+ """Predict results from a batch of inputs and data samples with post-
+ processing.
+
+ Args:
+ batch_inputs (Tensor): Inputs, has shape (bs, dim, H, W).
+ batch_data_samples (List[:obj:`DetDataSample`]): The batch
+ data samples. It usually includes information such
+ as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`.
+ rescale (bool): Whether to rescale the results.
+ Defaults to True.
+
+ Returns:
+ list[:obj:`DetDataSample`]: Detection results of the input images.
+ Each DetDataSample usually contain 'pred_instances'. And the
+ `pred_instances` usually contains following keys.
+
+ - scores (Tensor): Classification scores, has a shape
+ (num_instance, )
+ - labels (Tensor): Labels of bboxes, has a shape
+ (num_instances, ).
+ - bboxes (Tensor): Has a shape (num_instances, 4),
+ the last dimension 4 arrange as (x1, y1, x2, y2).
+ """
+
+ img_feats = self.extract_feat(batch_inputs)
+ head_inputs_dict = self.forward_transformer(
+ img_feats, batch_data_samples, test_mode=True)
+
+ results_list = self.bbox_head.predict(
+ **head_inputs_dict,
+ rescale=rescale,
+ batch_data_samples=batch_data_samples)
+ batch_data_samples = self.add_pred_to_datasample(
+ batch_data_samples, results_list)
+ return batch_data_samples
+
+ def forward_encoder(self,
+ feat: Tensor,
+ feat_mask: Tensor,
+ feat_pos: Tensor,
+ spatial_shapes: Tensor,
+ level_start_index: Tensor,
+ valid_ratios: Tensor,
+ test_mode: bool = True) -> Dict:
+ """Forward with Transformer encoder.
+
+ Args:
+ feat (Tensor): Sequential features, has shape (bs, num_feat_points,
+ dim).
+ feat_mask (Tensor): ByteTensor, the padding mask of the features,
+ has shape (bs, num_feat_points).
+ feat_pos (Tensor): The positional embeddings of the features, has
+ shape (bs, num_feat_points, dim).
+ spatial_shapes (Tensor): Spatial shapes of features in all levels,
+ has shape (num_levels, 2), last dimension represents (h, w).
+ level_start_index (Tensor): The start index of each level.
+ A tensor has shape (num_levels, ) and can be represented
+ as [0, h_0*w_0, h_0*w_0+h_1*w_1, ...].
+ valid_ratios (Tensor): The ratios of the valid width and the valid
+ height relative to the width and the height of features in all
+ levels, has shape (bs, num_levels, 2).
+ test_mode (bool, optional): If True, the function operates in test
+ mode. Defaults to True.
+
+ Returns:
+ dict: The dictionary of encoder outputs, which includes the
+ `memory` of the encoder output.
+ """
+ memory = self.encoder(
+ query=feat,
+ query_pos=feat_pos,
+ key_padding_mask=feat_mask, # for self_attn
+ spatial_shapes=spatial_shapes,
+ level_start_index=level_start_index,
+ valid_ratios=valid_ratios)
+
+ encoder_outputs_dict = dict(
+ memory=memory,
+ memory_mask=feat_mask,
+ spatial_shapes=spatial_shapes)
+
+ # only for training
+ heatmap_dict = dict()
+ if not test_mode:
+ batch_size = memory.size(0)
+ hm_memory = memory[:, level_start_index[0]:level_start_index[1], :]
+ hm_pos_embed = feat_pos[:, level_start_index[0]:
+ level_start_index[1], :]
+ hm_mask = feat_mask[:, level_start_index[0]:level_start_index[1]]
+ hm_memory = self.hm_encoder(
+ query=hm_memory,
+ query_pos=hm_pos_embed,
+ key_padding_mask=hm_mask,
+ spatial_shapes=spatial_shapes.narrow(0, 0, 1),
+ level_start_index=level_start_index[0],
+ valid_ratios=valid_ratios.narrow(1, 0, 1))
+ hm_memory = hm_memory.reshape(batch_size, spatial_shapes[0, 0],
+ spatial_shapes[0, 1], -1)
+ hm_mask = hm_mask.reshape(batch_size, spatial_shapes[0, 0],
+ spatial_shapes[0, 1])
+
+ heatmap_dict['hm_memory'] = hm_memory
+ heatmap_dict['hm_mask'] = hm_mask
+
+ return encoder_outputs_dict, heatmap_dict
+
+ def pre_decoder(self, memory: Tensor, memory_mask: Tensor,
+ spatial_shapes: Tensor) -> Tuple[Dict, Dict]:
+ """Prepare intermediate variables before entering Transformer decoder,
+ such as `query`, `query_pos`, and `reference_points`.
+
+ Args:
+ memory (Tensor): The output embeddings of the Transformer encoder,
+ has shape (bs, num_feat_points, dim).
+ memory_mask (Tensor): ByteTensor, the padding mask of the memory,
+ has shape (bs, num_feat_points). It will only be used when
+ `as_two_stage` is `True`.
+ spatial_shapes (Tensor): Spatial shapes of features in all levels,
+ has shape (num_levels, 2), last dimension represents (h, w).
+ It will only be used when `as_two_stage` is `True`.
+
+ Returns:
+ tuple[dict, dict]: The decoder_inputs_dict and head_inputs_dict.
+
+ - decoder_inputs_dict (dict): The keyword dictionary args of
+ `self.forward_decoder()`, which includes 'query', 'query_pos',
+ 'memory', and `reference_points`. The reference_points of
+ decoder input here are 4D boxes when `as_two_stage` is `True`,
+ otherwise 2D points, although it has `points` in its name.
+ The reference_points in encoder is always 2D points.
+ - head_inputs_dict (dict): The keyword dictionary args of the
+ bbox_head functions, which includes `enc_outputs_class` and
+ `enc_outputs_coord`. They are both `None` when 'as_two_stage'
+ is `False`. The dict is empty when `self.training` is `False`.
+ """
+ batch_size, _, c = memory.shape
+
+ query_embed = self.query_embedding.weight
+ query_pos, query = torch.split(query_embed, c, dim=1)
+ query_pos = query_pos.unsqueeze(0).expand(batch_size, -1, -1)
+ query = query.unsqueeze(0).expand(batch_size, -1, -1)
+
+ if self.as_two_stage:
+ output_memory, output_proposals = \
+ self.gen_encoder_output_proposals(
+ memory, memory_mask, spatial_shapes)
+ enc_outputs_class = self.bbox_head.cls_branches[
+ self.decoder.num_layers](
+ output_memory)
+ enc_outputs_coord_unact = self.bbox_head.reg_branches[
+ self.decoder.num_layers](
+ output_memory)
+ enc_outputs_coord_unact[..., 0::2] += output_proposals[..., 0:1]
+ enc_outputs_coord_unact[..., 1::2] += output_proposals[..., 1:2]
+ enc_outputs_coord = enc_outputs_coord_unact.sigmoid()
+ topk_proposals = torch.topk(
+ enc_outputs_class[..., 0], self.num_queries, dim=1)[1]
+ topk_coords_unact = torch.gather(
+ enc_outputs_coord_unact, 1,
+ topk_proposals.unsqueeze(-1).repeat(
+ 1, 1, enc_outputs_coord_unact.size(-1)))
+ topk_coords_unact = topk_coords_unact.detach()
+ reference_points = topk_coords_unact.sigmoid()
+ else:
+ enc_outputs_class, enc_outputs_coord = None, None
+ reference_points = self.reference_points_fc(query_pos).sigmoid()
+
+ decoder_inputs_dict = dict(
+ query=query,
+ query_pos=query_pos,
+ memory=memory,
+ reference_points=reference_points)
+ head_inputs_dict = dict(
+ enc_outputs_class=enc_outputs_class,
+ enc_outputs_coord=enc_outputs_coord) if self.training else dict()
+ return decoder_inputs_dict, head_inputs_dict
+
+ def forward_decoder(self, query: Tensor, query_pos: Tensor, memory: Tensor,
+ memory_mask: Tensor, reference_points: Tensor,
+ spatial_shapes: Tensor, level_start_index: Tensor,
+ valid_ratios: Tensor) -> Dict:
+ """Forward with Transformer decoder.
+
+ The forward procedure of the transformer is defined as:
+ 'pre_transformer' -> 'encoder' -> 'pre_decoder' -> 'decoder'
+ More details can be found at `TransformerDetector.forward_transformer`
+ in `mmdet/detector/base_detr.py`.
+
+ Args:
+ query (Tensor): The queries of decoder inputs, has shape
+ (bs, num_queries, dim).
+ query_pos (Tensor): The positional queries of decoder inputs,
+ has shape (bs, num_queries, dim).
+ memory (Tensor): The output embeddings of the Transformer encoder,
+ has shape (bs, num_feat_points, dim).
+ memory_mask (Tensor): ByteTensor, the padding mask of the memory,
+ has shape (bs, num_feat_points).
+ reference_points (Tensor): The initial reference, has shape
+ (bs, num_queries, 4) with the last dimension arranged as
+ (cx, cy, w, h) when `as_two_stage` is `True`, otherwise has
+ shape (bs, num_queries, 2) with the last dimension arranged as
+ (cx, cy).
+ spatial_shapes (Tensor): Spatial shapes of features in all levels,
+ has shape (num_levels, 2), last dimension represents (h, w).
+ level_start_index (Tensor): The start index of each level.
+ A tensor has shape (num_levels, ) and can be represented
+ as [0, h_0*w_0, h_0*w_0+h_1*w_1, ...].
+ valid_ratios (Tensor): The ratios of the valid width and the valid
+ height relative to the width and the height of features in all
+ levels, has shape (bs, num_levels, 2).
+
+ Returns:
+ dict: The dictionary of decoder outputs, which includes the
+ `hidden_states` of the decoder output and `references` including
+ the initial and intermediate reference_points.
+ """
+ inter_states, inter_references = self.decoder(
+ query=query,
+ value=memory,
+ query_pos=query_pos,
+ key_padding_mask=memory_mask, # for cross_attn
+ reference_points=reference_points,
+ spatial_shapes=spatial_shapes,
+ level_start_index=level_start_index,
+ valid_ratios=valid_ratios,
+ reg_branches=self.bbox_head.reg_branches
+ if self.with_box_refine else None)
+ references = [reference_points, *inter_references]
+
+ all_layers_classes, all_layers_coords = self.bbox_head(
+ hidden_states=inter_states, references=references)
+
+ decoder_outputs_dict = dict(
+ hidden_states=inter_states,
+ references=references,
+ all_layers_classes=all_layers_classes,
+ all_layers_coords=all_layers_coords)
+
+ return decoder_outputs_dict
+
+ def pre_kpt_decoder(self,
+ all_layers_classes,
+ all_layers_coords,
+ batch_data_samples,
+ test_mode=False,
+ **kwargs):
+ """Prepares the inputs for the keypoint decoder.
+
+ Args:
+ all_layers_classes (Tensor): Classification scores of all layers
+ all_layers_coords (Tensor): Coordinates of keypoints of all layers
+ batch_data_samples (list): List of samples in a batch
+ test_mode (bool, optional): If True, the function will run in test
+ mode. Defaults to False.
+ """
+ cls_scores = all_layers_classes[-1]
+ kpt_coords = all_layers_coords[-1]
+
+ if test_mode:
+ assert cls_scores.size(0) == 1, \
+ f'only `batch_size=1` is supported in testing, but got ' \
+ f'{cls_scores.size(0)}'
+
+ cls_scores = cls_scores[0]
+ kpt_coords = kpt_coords[0]
+
+ max_per_img = self.test_cfg['max_per_img']
+ if self.bbox_head.loss_cls.use_sigmoid:
+ cls_scores = cls_scores.sigmoid()
+ scores, indices = cls_scores.view(-1).topk(max_per_img)
+ det_labels = indices % self.bbox_head.num_classes
+ bbox_index = indices // self.bbox_head.num_classes
+ kpt_coords = kpt_coords[bbox_index]
+ else:
+ scores, det_labels = F.softmax(
+ cls_scores, dim=-1)[..., :-1].max(-1)
+ scores, bbox_index = scores.topk(max_per_img)
+ kpt_coords = kpt_coords[bbox_index]
+ det_labels = det_labels[bbox_index]
+
+ kpt_weights = torch.ones_like(kpt_coords)
+
+ kpt_decoder_inputs_dict = dict(
+ det_labels=det_labels,
+ det_scores=scores,
+ )
+
+ else:
+
+ batch_gt_instances = [ds.gt_instances for ds in batch_data_samples]
+ batch_img_metas = [ds.metainfo for ds in batch_data_samples]
+
+ num_imgs = cls_scores.size(0)
+ cls_scores_list = [cls_scores[i] for i in range(num_imgs)]
+ kpt_coords_list = [
+ kpt_coords[i].reshape(-1, self.bbox_head.num_keypoints, 2)
+ for i in range(num_imgs)
+ ]
+
+ cls_reg_targets = self.bbox_head.get_targets(
+ cls_scores_list,
+ kpt_coords_list,
+ batch_gt_instances,
+ batch_img_metas,
+ cache_targets=True)
+ kpt_weights = torch.cat(cls_reg_targets[4])
+ kpt_coords = kpt_coords.flatten(0, 1)
+ kpt_decoder_inputs_dict = {}
+
+ pos_inds = kpt_weights.sum(-1) > 0
+ if pos_inds.sum() == 0:
+ pos_kpt_coords = torch.zeros_like(kpt_coords[:1])
+ pos_img_inds = kpt_coords.new_zeros([1], dtype=torch.int64)
+ else:
+ pos_kpt_coords = kpt_coords[pos_inds]
+ pos_img_inds = (pos_inds.nonzero() /
+ self.num_queries).squeeze(1).to(torch.int64)
+
+ kpt_decoder_inputs_dict.update(
+ dict(
+ pos_kpt_coords=pos_kpt_coords,
+ pos_img_inds=pos_img_inds,
+ ))
+ return kpt_decoder_inputs_dict
+
+ def forward_kpt_decoder(self, memory, memory_mask, pos_kpt_coords,
+ pos_img_inds, spatial_shapes, level_start_index,
+ valid_ratios, **kwargs):
+ """Runs the keypoint decoder forward pass.
+
+ Args:
+ memory (Tensor): The output embeddings from the Transformer
+ encoder.
+ memory_mask (Tensor): The mask of the memory.
+ pos_kpt_coords (Tensor): Positive keypoint coordinates.
+ pos_img_inds (Tensor): Image indices of positive keypoints.
+ spatial_shapes (Tensor): Spatial shapes of features.
+ level_start_index (Tensor): Start index of each level.
+ valid_ratios (Tensor): Valid ratios of all images.
+ """
+
+ kpt_query_embedding = self.kpt_query_embedding.weight
+ query_pos, query = torch.split(
+ kpt_query_embedding, kpt_query_embedding.size(1) // 2, dim=1)
+ pos_num = pos_kpt_coords.size(0)
+ query_pos = query_pos.unsqueeze(0).expand(pos_num, -1, -1)
+ query = query.unsqueeze(0).expand(pos_num, -1, -1)
+ reference_points = pos_kpt_coords.reshape(pos_num,
+ pos_kpt_coords.size(1) // 2,
+ 2).detach()
+ pos_memory = memory[pos_img_inds, :, :]
+ memory_mask = memory_mask[pos_img_inds, :]
+ valid_ratios = valid_ratios[pos_img_inds, ...]
+
+ # forward_kpt_decoder
+ inter_states, inter_references = self.kpt_decoder(
+ query=query,
+ key=None,
+ value=pos_memory,
+ query_pos=query_pos,
+ key_padding_mask=memory_mask,
+ reference_points=reference_points,
+ spatial_shapes=spatial_shapes,
+ level_start_index=level_start_index,
+ valid_ratios=valid_ratios,
+ reg_branches=self.bbox_head.kpt_branches)
+
+ kpt_decoder_outputs_dict = dict(
+ inter_states=inter_states,
+ reference_points=reference_points,
+ inter_references=inter_references,
+ )
+
+ return kpt_decoder_outputs_dict
+
+ def forward_kpt_head(self, inter_states, reference_points,
+ inter_references):
+ """Runs the keypoint head forward pass.
+
+ Args:
+ inter_states (Tensor): Intermediate states from the keypoint
+ decoder.
+ reference_points (Tensor): Reference points from the keypoint
+ decoder.
+ inter_references (Tensor): Intermediate reference points from
+ the keypoint decoder.
+ """
+
+ outputs_kpts = []
+ for lvl in range(inter_states.shape[0]):
+ if lvl == 0:
+ reference = reference_points
+ else:
+ reference = inter_references[lvl - 1]
+ reference = inverse_sigmoid(reference)
+ tmp_kpt = self.bbox_head.kpt_branches[lvl](inter_states[lvl])
+ assert reference.shape[-1] == 2
+ tmp_kpt += reference
+ outputs_kpt = tmp_kpt.sigmoid()
+ outputs_kpts.append(outputs_kpt)
+
+ outputs_kpts = torch.stack(outputs_kpts)
+ return outputs_kpts
+
+ def _load_state_dict_pre_hook(self, state_dict, prefix, local_meta, *args,
+ **kwargs):
+ """A hook function to convert the state dict from official repo to a
+ compatible format :class:`PETR`.
+
+ The hook will be automatically registered during initialization.
+ """
+
+ if local_meta.get('version', self._version) >= self._version:
+ return
+
+ mappings = OrderedDict()
+ mappings['bbox_head.transformer.'] = ''
+ mappings['level_embeds'] = 'level_embed'
+ mappings['bbox_head.query_embedding'] = 'query_embedding'
+ mappings['refine_query_embedding'] = 'kpt_query_embedding'
+ mappings['attentions.0'] = 'self_attn'
+ mappings['attentions.1'] = 'cross_attn'
+ mappings['ffns.0'] = 'ffn'
+ mappings['bbox_head.kpt_branches'] = 'bbox_head.reg_branches'
+ mappings['bbox_head.refine_kpt_branches'] = 'bbox_head.kpt_branches'
+ mappings['refine_decoder'] = 'kpt_decoder'
+ mappings['bbox_head.fc_hm'] = 'bbox_head.heatmap_fc'
+ mappings['enc_output_norm'] = 'memory_trans_norm'
+ mappings['enc_output'] = 'memory_trans_fc'
+
+ # convert old-version state dict
+ for old_key, new_key in mappings.items():
+ keys = list(state_dict.keys())
+ for k in keys:
+ if old_key in k:
+ v = state_dict.pop(k)
+ k = k.replace(old_key, new_key)
+ state_dict[k] = v
diff --git a/projects/petr/models/petr_head.py b/projects/petr/models/petr_head.py
new file mode 100644
index 0000000000..8aa4c2d950
--- /dev/null
+++ b/projects/petr/models/petr_head.py
@@ -0,0 +1,581 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import copy
+from typing import List, Tuple
+
+import torch
+import torch.nn as nn
+from mmcv.cnn import Linear
+from mmdet.models import inverse_sigmoid
+from mmdet.models.dense_heads import DeformableDETRHead
+from mmdet.models.utils import multi_apply
+from mmdet.registry import MODELS
+from mmdet.structures import SampleList
+from mmdet.utils import InstanceList, reduce_mean
+from mmengine.model import bias_init_with_prob, constant_init, normal_init
+from mmengine.structures import InstanceData
+from torch import Tensor
+
+
+@MODELS.register_module()
+class PETRHead(DeformableDETRHead):
+ r"""Head of PETR: End-to-End Multi-Person Pose Estimation
+ with Transformers.
+
+ Code is modified from the `official github repo
+ `_.
+
+ More details can be found in the `paper
+ `_ .
+
+ Args:
+ num_keypoints (int): Number of keypoints. Defaults to 17.
+ num_pred_kpt_layer (int): The number of the keypoint refine decoder
+ layers. Defaults to 2.
+ loss_reg (dict): The configuration dict of regression loss for
+ outputs of refine decoders.
+ loss_reg_aux (dict): The configuration dict of regression loss for
+ outputs of decoders.
+ loss_oks (dict): The configuration dict of oks loss for outputs
+ of refine decoders.
+ loss_oks_aux (dict): The configuration dict of oks loss for
+ outputs of decoders.
+ loss_hm (dict): The configuration dict of heatmap loss.
+ """
+
+ def __init__(self,
+ num_keypoints: int = 17,
+ num_pred_kpt_layer: int = 2,
+ loss_reg: dict = None,
+ loss_reg_aux: dict = None,
+ loss_oks: dict = None,
+ loss_oks_aux: dict = None,
+ loss_hm: dict = None,
+ *args,
+ **kwargs):
+ self.num_keypoints = num_keypoints
+ self.num_pred_kpt_layer = num_pred_kpt_layer
+ super().__init__(*args, **kwargs)
+
+ self._target_buffer = dict()
+
+ self.loss_reg = MODELS.build(loss_reg)
+ self.loss_reg_aux = MODELS.build(loss_reg_aux)
+ self.loss_oks = MODELS.build(loss_oks)
+ self.loss_oks_aux = MODELS.build(loss_oks_aux)
+ self.loss_hm = MODELS.build(loss_hm)
+
+ def _init_layers(self) -> None:
+ """Initialize classification branch and regression branch of head."""
+ fc_cls = Linear(self.embed_dims, self.cls_out_channels)
+
+ reg_branch = [Linear(self.embed_dims, self.embed_dims * 2), nn.ReLU()]
+ for _ in range(self.num_reg_fcs):
+ reg_branch.append(Linear(self.embed_dims * 2, self.embed_dims * 2))
+ reg_branch.append(nn.ReLU())
+ reg_branch.append(Linear(self.embed_dims * 2, self.num_keypoints * 2))
+ reg_branch = nn.Sequential(*reg_branch)
+
+ kpt_branch = []
+ for _ in range(self.num_reg_fcs):
+ kpt_branch.append(Linear(self.embed_dims, self.embed_dims))
+ kpt_branch.append(nn.ReLU())
+ kpt_branch.append(Linear(self.embed_dims, 2))
+ kpt_branch = nn.Sequential(*kpt_branch)
+
+ if self.share_pred_layer:
+ self.cls_branches = nn.ModuleList(
+ [fc_cls for _ in range(self.num_pred_layer)])
+ self.reg_branches = nn.ModuleList(
+ [reg_branch for _ in range(self.num_pred_layer)])
+ self.kpt_branches = nn.ModuleList(
+ [kpt_branch for _ in range(self.num_pred_kpt_layer)])
+ else:
+ self.cls_branches = nn.ModuleList(
+ [copy.deepcopy(fc_cls) for _ in range(self.num_pred_layer)])
+ self.reg_branches = nn.ModuleList([
+ copy.deepcopy(reg_branch) for _ in range(self.num_pred_layer)
+ ])
+ self.kpt_branches = nn.ModuleList([
+ copy.deepcopy(kpt_branch)
+ for _ in range(self.num_pred_kpt_layer)
+ ])
+
+ self.heatmap_fc = Linear(self.embed_dims, self.num_keypoints)
+
+ def init_weights(self) -> None:
+ """Initialize weights of the PETR head."""
+ if self.loss_cls.use_sigmoid:
+ bias_init = bias_init_with_prob(0.01)
+ for m in self.cls_branches:
+ nn.init.constant_(m.bias, bias_init)
+ for m in self.reg_branches:
+ constant_init(m[-1], 0, bias=0)
+ for m in self.kpt_branches:
+ constant_init(m[-1], 0, bias=0)
+ # initialize bias for heatmap prediction
+ bias_init = bias_init_with_prob(0.1)
+ normal_init(self.heatmap_fc, std=0.01, bias=bias_init)
+
+ def forward(self, hidden_states: Tensor,
+ references: List[Tensor]) -> Tuple[Tensor]:
+ """Forward function."""
+ all_layers_outputs_classes = []
+ all_layers_outputs_coords = []
+
+ for layer_id in range(hidden_states.shape[0]):
+ reference = inverse_sigmoid(references[layer_id])
+ # NOTE The last reference will not be used.
+ hidden_state = hidden_states[layer_id]
+ outputs_class = self.cls_branches[layer_id](hidden_state)
+ tmp_reg_preds = self.reg_branches[layer_id](hidden_state)
+ if reference.shape[-1] == self.num_keypoints * 2:
+ # When `layer` is 0 and `as_two_stage` of the detector
+ # is `True`, or when `layer` is greater than 0 and
+ # `with_box_refine` of the detector is `True`.
+ tmp_reg_preds += reference
+ else:
+ # When `layer` is 0 and `as_two_stage` of the detector
+ # is `False`, or when `layer` is greater than 0 and
+ # `with_box_refine` of the detector is `False`.
+ assert reference.shape[-1] == 2
+ tmp_reg_preds[..., :2] += reference
+ outputs_coord = tmp_reg_preds.sigmoid()
+ all_layers_outputs_classes.append(outputs_class)
+ all_layers_outputs_coords.append(outputs_coord)
+
+ all_layers_outputs_classes = torch.stack(all_layers_outputs_classes)
+ all_layers_outputs_coords = torch.stack(all_layers_outputs_coords)
+
+ return all_layers_outputs_classes, all_layers_outputs_coords
+
+ def predict(self,
+ dec_outputs_coord: Tensor,
+ det_labels: Tensor,
+ det_scores: Tensor,
+ batch_data_samples: SampleList,
+ rescale: bool = True) -> InstanceList:
+ """Perform forward propagation and loss calculation of the detection
+ head on the queries of the upstream network."""
+ batch_img_metas = [
+ data_samples.metainfo for data_samples in batch_data_samples
+ ]
+ assert len(batch_img_metas) == 1, f'PETR only support test with ' \
+ f'batch size 1, but got {len(batch_img_metas)}.'
+ img_meta = batch_img_metas[0]
+
+ if dec_outputs_coord.ndim == 4:
+ dec_outputs_coord = dec_outputs_coord[-1]
+
+ # filter instance
+ if self.test_cfg.get('score_thr', .0) > 0:
+ kept_inds = det_scores > self.test_cfg['score_thr']
+ det_labels = det_labels[kept_inds]
+ det_scores = det_scores[kept_inds]
+ dec_outputs_coord = dec_outputs_coord[kept_inds]
+
+ if len(dec_outputs_coord) > 0:
+ # decode keypoints
+ h, w = img_meta['img_shape']
+ if rescale:
+ h = h / img_meta['scale_factor'][0]
+ w = w / img_meta['scale_factor'][1]
+ keypoints = torch.stack([
+ dec_outputs_coord[..., 0] * w,
+ dec_outputs_coord[..., 1] * h,
+ ],
+ dim=2)
+ keypoint_scores = torch.ones(keypoints.shape[:-1])
+
+ # generate bounding boxes by outlining the detected poses
+ bboxes = torch.stack([
+ keypoints[..., 0].min(dim=1).values.clamp(0, w),
+ keypoints[..., 1].min(dim=1).values.clamp(0, h),
+ keypoints[..., 0].max(dim=1).values.clamp(0, w),
+ keypoints[..., 1].max(dim=1).values.clamp(0, h),
+ ],
+ dim=1)
+ else:
+ keypoints = torch.empty(0, *dec_outputs_coord.shape[1:])
+ keypoint_scores = torch.ones(keypoints.shape[:-1])
+ bboxes = torch.empty(0, 4)
+
+ results = InstanceData()
+ results.set_metainfo(img_meta)
+ results.bboxes = bboxes
+ results.scores = det_scores
+ results.bbox_scores = det_scores
+ results.labels = det_labels
+ results.keypoints = keypoints
+ results.keypoint_scores = keypoint_scores
+ results = results.numpy()
+
+ return [results]
+
+ def loss(self, enc_outputs_class: Tensor, enc_outputs_coord: Tensor,
+ all_layers_classes: Tensor, all_layers_coords: Tensor,
+ hm_memory: Tensor, hm_mask: Tensor, dec_outputs_coord: Tensor,
+ batch_data_samples: SampleList) -> dict:
+ """Perform forward propagation and loss calculation of the detection
+ head on the queries of the upstream network."""
+
+ batch_gt_instances = []
+ batch_img_metas = []
+ batch_gt_fields = []
+ for data_sample in batch_data_samples:
+ batch_img_metas.append(data_sample.metainfo)
+ batch_gt_instances.append(data_sample.gt_instances)
+ batch_gt_fields.append(data_sample.gt_fields)
+
+ loss_dict = dict()
+
+ # calculate loss for decoder output
+ losses_cls, losses_kpt, losses_oks = multi_apply(
+ self.loss_by_feat_single,
+ all_layers_classes,
+ all_layers_coords,
+ batch_gt_instances=batch_gt_instances,
+ batch_img_metas=batch_img_metas)
+ # loss from the last decoder layer
+ loss_dict['loss_cls'] = losses_cls[-1]
+ loss_dict['loss_kpt'] = losses_kpt[-1]
+ loss_dict['loss_oks'] = losses_oks[-1]
+ # loss from other decoder layers
+ num_dec_layer = 0
+ for loss_cls_i, loss_kpt_i, loss_oks_i in zip(losses_cls[:-1],
+ losses_kpt[:-1],
+ losses_oks[:-1]):
+ loss_dict[f'd{num_dec_layer}.loss_cls'] = loss_cls_i
+ loss_dict[f'd{num_dec_layer}.loss_kpt'] = loss_kpt_i
+ loss_dict[f'd{num_dec_layer}.loss_oks'] = loss_oks_i
+ num_dec_layer += 1
+
+ # calculate loss for encoder output
+ losses_cls, losses_kpt, _ = self.loss_by_feat_single(
+ enc_outputs_class,
+ enc_outputs_coord,
+ batch_gt_instances=batch_gt_instances,
+ batch_img_metas=batch_img_metas,
+ compute_oks_loss=False)
+ loss_dict['enc_loss_cls'] = losses_cls
+ loss_dict['enc_loss_kpt'] = losses_kpt
+
+ # calculate heatmap loss
+ loss_hm = self.loss_heatmap(
+ hm_memory, hm_mask, batch_gt_fields=batch_gt_fields)
+ loss_dict['loss_hm'] = loss_hm
+
+ # calculate loss for kpt_decoder output
+ losses_kpt, losses_oks = multi_apply(
+ self.loss_refined_kpts,
+ dec_outputs_coord,
+ batch_img_metas=batch_img_metas,
+ )
+
+ num_dec_layer = 0
+ for loss_kpt_i, loss_oks_i in zip(losses_kpt, losses_oks):
+ loss_dict[f'd{num_dec_layer}.loss_kpt_refine'] = loss_kpt_i
+ loss_dict[f'd{num_dec_layer}.loss_oks_refine'] = loss_oks_i
+ num_dec_layer += 1
+
+ self._target_buffer.clear()
+ return loss_dict
+
+ # TODO: rename this method
+ def loss_by_feat_single(self,
+ cls_scores: Tensor,
+ kpt_preds: Tensor,
+ batch_gt_instances: InstanceList,
+ batch_img_metas: List[dict],
+ compute_oks_loss: bool = True) -> Tuple[Tensor]:
+ """Loss function for outputs from a single decoder layer of a single
+ feature level."""
+
+ # cls_scores [bs, num_queries, 1]
+ # kpt_preds [bs, num_queries, 2*num_keypoitns]
+
+ num_imgs, num_queries = cls_scores.shape[:2]
+
+ cls_scores_list = [cls_scores[i] for i in range(num_imgs)]
+ kpt_preds_list = [
+ kpt_preds[i].reshape(-1, self.num_keypoints, 2)
+ for i in range(num_imgs)
+ ]
+ cls_reg_targets = self.get_targets(cls_scores_list, kpt_preds_list,
+ batch_gt_instances, batch_img_metas)
+ (labels_list, label_weights_list, bbox_targets_list, kpt_targets_list,
+ kpt_weights_list, area_targets_list, num_total_pos,
+ num_total_neg) = cls_reg_targets
+ labels = torch.cat(labels_list, 0) # [bs*300]
+ label_weights = torch.cat(label_weights_list, 0) # [bs*300] (all 1)
+ kpt_targets = torch.cat(kpt_targets_list,
+ 0) # [bs*300, 17, 2] (normalized)
+ kpt_weights = torch.cat(kpt_weights_list, 0) # [bs*300, 17]
+ area_targets = torch.cat(area_targets_list, 0) # [bs*300]
+
+ # keypoint regression loss
+ kpt_preds = kpt_preds.reshape(-1, self.num_keypoints, 2)
+ num_valid_kpt = torch.clamp(
+ reduce_mean(kpt_weights.sum()), min=1).item()
+ # assert num_valid_kpt == (kpt_targets>0).sum().item()
+ loss_kpt = self.loss_reg_aux(
+ kpt_preds,
+ kpt_targets,
+ kpt_weights.unsqueeze(-1),
+ avg_factor=num_valid_kpt)
+
+ # classification loss
+ cls_scores = cls_scores.reshape(-1, self.cls_out_channels)
+ # construct weighted avg_factor to match with the official DETR repo
+ cls_avg_factor = num_total_pos * 1.0 + \
+ num_total_neg * self.bg_cls_weight
+ if self.sync_cls_avg_factor:
+ cls_avg_factor = reduce_mean(
+ cls_scores.new_tensor([cls_avg_factor]))
+ cls_avg_factor = max(cls_avg_factor, 1)
+ loss_cls = self.loss_cls(
+ cls_scores, labels, label_weights, avg_factor=cls_avg_factor)
+
+ pos_mask = kpt_weights.sum(-1) > 0
+ if compute_oks_loss and pos_mask.any().item():
+ pos_inds = (pos_mask.nonzero()).div(
+ num_queries, rounding_mode='trunc').squeeze(-1)
+
+ # construct factors used for rescale keypoints
+ factors = []
+ for img_meta, kpt_pred in zip(batch_img_metas, kpt_preds):
+ img_h, img_w, = img_meta['img_shape']
+ factor = kpt_pred.new_tensor([img_w, img_h]).reshape(1, 1, 2)
+ factors.append(factor)
+ factors = torch.cat(factors, 0)
+ factors = factors[pos_inds]
+
+ # keypoint oks loss
+ pos_kpt_preds = kpt_preds[pos_mask] * factors
+ pos_kpt_targets = kpt_targets[pos_mask] * factors
+ pos_kpt_weights = kpt_weights[pos_mask]
+ pos_area_targets = area_targets[pos_mask]
+
+ loss_oks = self.loss_oks_aux(pos_kpt_preds, pos_kpt_targets,
+ pos_kpt_weights, pos_area_targets)
+ else:
+ loss_oks = torch.zeros_like(loss_kpt)
+
+ return loss_cls, loss_kpt, loss_oks
+
+ def loss_refined_kpts(self, kpt_preds: Tensor,
+ batch_img_metas: List[dict]) -> Tuple[Tensor]:
+ """Loss function for outputs from a single decoder layer of a single
+ feature level."""
+
+ # kpt_preds [num_selected, num_keypoints, 2]
+ kpt_targets_list = self._target_buffer['kpt_targets_list']
+ kpt_weights_list = self._target_buffer['kpt_weights_list']
+ area_targets_list = self._target_buffer['area_targets_list']
+ num_queries = len(kpt_targets_list[0])
+ kpt_targets = torch.cat(kpt_targets_list,
+ 0).contiguous() # [bs*300, 17, 2] (normalized)
+ kpt_weights = torch.cat(kpt_weights_list,
+ 0).contiguous() # [bs*300, 17]
+ area_targets = torch.cat(area_targets_list, 0).contiguous() # [bs*300]
+
+ pos_mask = (kpt_weights.sum(-1) > 0).contiguous()
+ pos_inds = (pos_mask.nonzero()).div(
+ num_queries, rounding_mode='trunc').squeeze(-1)
+
+ # keypoint regression loss
+ kpt_preds = kpt_preds.reshape(-1, self.num_keypoints, 2)
+ num_valid_kpt = torch.clamp(
+ reduce_mean(kpt_weights.sum()), min=1).item()
+ # assert num_valid_kpt == (kpt_targets>0).sum().item()
+ loss_kpt = self.loss_reg(
+ kpt_preds,
+ kpt_targets[pos_mask],
+ kpt_weights[pos_mask].unsqueeze(-1),
+ avg_factor=num_valid_kpt)
+
+ if pos_mask.any().item():
+ # construct factors used for rescale keypoints
+ factors = []
+ for img_meta in batch_img_metas:
+ img_h, img_w, = img_meta['img_shape']
+ factor = kpt_preds.new_tensor([img_w, img_h]).reshape(1, 1, 2)
+ factors.append(factor)
+ factors = torch.cat(factors, 0)
+
+ factors = factors[pos_inds]
+
+ # keypoint oks loss
+ pos_kpt_preds = kpt_preds * factors
+ pos_kpt_targets = kpt_targets[pos_mask]
+
+ pos_kpt_targets = pos_kpt_targets * factors
+ pos_kpt_weights = kpt_weights[pos_mask]
+ # pos_bbox_targets = (bbox_targets[pos_mask].reshape(-1, 2, 2) *
+ # factors).flatten(-2)
+ pos_area_targets = area_targets[pos_mask]
+
+ loss_oks = self.loss_oks_aux(pos_kpt_preds, pos_kpt_targets,
+ pos_kpt_weights, pos_area_targets)
+ else:
+ loss_oks = torch.zeros_like(loss_kpt)
+
+ return loss_kpt, loss_oks
+
+ def loss_heatmap(self, hm_memory, hm_mask, batch_gt_fields):
+ """Heatmap loss function for outputs from the heatmap encoder."""
+
+ # compute heatmap predition
+ pred_heatmaps = self.heatmap_fc(hm_memory)
+ pred_heatmaps = torch.clamp(
+ pred_heatmaps.sigmoid_(), min=1e-4, max=1 - 1e-4)
+ pred_heatmaps = pred_heatmaps.permute(0, 3, 1, 2).contiguous()
+
+ # construct heatmap target
+ gt_heatmaps = torch.zeros_like(pred_heatmaps)
+ for i, gf in enumerate(batch_gt_fields):
+ gt_heatmap = gf.gt_heatmaps
+ h = min(gt_heatmap.size(1), gt_heatmaps.size(2))
+ w = min(gt_heatmap.size(2), gt_heatmaps.size(3))
+ gt_heatmaps[i, :, :h, :w] = gt_heatmap[:, :h, :w]
+
+ loss_hm = self.loss_hm(pred_heatmaps, gt_heatmaps, None,
+ 1 - hm_mask.unsqueeze(1).float())
+ return loss_hm
+
+ @torch.no_grad()
+ def get_targets(self,
+ cls_scores_list: List[Tensor],
+ kpt_preds_list: List[Tensor],
+ batch_gt_instances: InstanceList,
+ batch_img_metas: List[dict],
+ cache_targets: bool = False) -> tuple:
+ """Compute regression and classification targets for a batch image.
+
+ Outputs from a single decoder layer of a single feature level are used.
+
+ Args:
+ cls_scores_list (list[Tensor]): Box score logits from a single
+ decoder layer for each image, has shape [num_queries,
+ cls_out_channels].
+ kpt_preds_list (list[Tensor]): Sigmoid outputs from a single
+ decoder layer for each image, with normalized coordinate
+ (cx, cy, w, h) and shape [num_queries, 4].
+ batch_gt_instances (list[:obj:`InstanceData`]): Batch of
+ gt_instance. It usually includes ``bboxes`` and ``labels``
+ attributes.
+ batch_img_metas (list[dict]): Meta information of each image, e.g.,
+ image size, scaling factor, etc.
+
+ Returns:
+ tuple: a tuple containing the following targets.
+
+ - labels_list (list[Tensor]): Labels for all images.
+ - label_weights_list (list[Tensor]): Label weights for all images.
+ - bbox_targets_list (list[Tensor]): BBox targets for all images.
+ - kpt_targets_list (list[Tensor]): Keypoint targets for all images.
+ - kpt_weights_list (list[Tensor]): Keypoint weights for all images.
+ - num_total_pos (int): Number of positive samples in all images.
+ - num_total_neg (int): Number of negative samples in all images.
+ """
+ (labels_list, label_weights_list, bbox_targets_list, kpt_targets_list,
+ kpt_weights_list, area_targets_list, pos_inds_list,
+ neg_inds_list) = multi_apply(self._get_targets_single,
+ cls_scores_list, kpt_preds_list,
+ batch_gt_instances, batch_img_metas)
+ num_total_pos = sum((inds.numel() for inds in pos_inds_list))
+ num_total_neg = sum((inds.numel() for inds in neg_inds_list))
+
+ if cache_targets:
+ self._target_buffer['labels_list'] = labels_list
+ self._target_buffer['label_weights_list'] = label_weights_list
+ self._target_buffer['bbox_targets_list'] = bbox_targets_list
+ self._target_buffer['kpt_targets_list'] = kpt_targets_list
+ self._target_buffer['kpt_weights_list'] = kpt_weights_list
+ self._target_buffer['area_targets_list'] = area_targets_list
+ self._target_buffer['num_total_pos'] = num_total_pos
+ self._target_buffer['num_total_neg'] = num_total_neg
+
+ return (labels_list, label_weights_list, bbox_targets_list,
+ kpt_targets_list, kpt_weights_list, area_targets_list,
+ num_total_pos, num_total_neg)
+
+ def _get_targets_single(self, cls_score: Tensor, kpt_pred: Tensor,
+ gt_instances: InstanceData,
+ img_meta: dict) -> tuple:
+ """Compute regression and classification targets for one image.
+
+ Outputs from a single decoder layer of a single feature level are used.
+
+ Args:
+ cls_score (Tensor): Box score logits from a single decoder layer
+ for one image. Shape [num_queries, cls_out_channels].
+ kpt_pred (Tensor): Sigmoid outputs from a single decoder layer
+ for one image, with normalized coordinate (cx, cy, w, h) and
+ shape [num_queries, 4].
+ gt_instances (:obj:`InstanceData`): Ground truth of instance
+ annotations. It should includes ``bboxes`` and ``labels``
+ attributes.
+ img_meta (dict): Meta information for one image.
+
+ Returns:
+ tuple[Tensor]: a tuple containing the following for one image.
+
+ - labels (Tensor): Labels of each image.
+ - label_weights (Tensor]): Label weights of each image.
+ - bbox_targets (Tensor): BBox targets of each image.
+ - kpt_targets (Tensor): Keypoint targets of each image.
+ - kpt_weights (Tensor): Keypoint weights of each image.
+ - pos_inds (Tensor): Sampled positive indices for each image.
+ - neg_inds (Tensor): Sampled negative indices for each image.
+ """
+ img_h, img_w = img_meta['img_shape']
+ num_insts = kpt_pred.size(0)
+ factor = kpt_pred.new_tensor([img_w, img_h]).unsqueeze(0).unsqueeze(1)
+ kpt_pred = kpt_pred * factor
+
+ pred_instances = InstanceData(scores=cls_score, keypoints=kpt_pred)
+ # assigner and sampler
+ assign_result = self.assigner.assign(
+ pred_instances=pred_instances,
+ gt_instances=gt_instances,
+ img_meta=img_meta)
+
+ gt_keypoints = gt_instances.keypoints
+ gt_keypoints_visible = gt_instances.keypoints_visible
+ gt_labels = gt_instances.labels
+ gt_bboxes = gt_instances.bboxes
+ gt_areas = gt_instances.area
+ pos_inds = torch.nonzero(
+ assign_result.gt_inds > 0, as_tuple=False).squeeze(-1).unique()
+ neg_inds = torch.nonzero(
+ assign_result.gt_inds == 0, as_tuple=False).squeeze(-1).unique()
+ pos_assigned_gt_inds = assign_result.gt_inds[pos_inds] - 1
+
+ # label targets
+ labels = gt_labels.new_full((num_insts, ),
+ self.num_classes,
+ dtype=torch.long)
+ labels[pos_inds] = gt_labels[pos_assigned_gt_inds]
+ label_weights = gt_labels.new_ones(num_insts)
+
+ # kpt targets
+ kpt_targets = torch.zeros_like(kpt_pred)
+ pos_gt_keypoints = gt_keypoints[pos_assigned_gt_inds.long(), :]
+ kpt_targets[pos_inds] = pos_gt_keypoints / factor
+ kpt_weights = torch.zeros_like(kpt_pred).narrow(-1, 0, 1).squeeze(-1)
+ pos_gt_keypoints_visible = gt_keypoints_visible[
+ pos_assigned_gt_inds.long()]
+ kpt_weights[pos_inds] = (pos_gt_keypoints_visible > 0).float()
+
+ # bbox_targets, which is used to compute oks loss
+ bbox_targets = torch.zeros_like(kpt_pred).narrow(-2, 0, 2)
+ pos_gt_bboxes = gt_bboxes[pos_assigned_gt_inds.long()]
+ bbox_targets[pos_inds] = pos_gt_bboxes.reshape(
+ *pos_gt_bboxes.shape[:-1], 2, 2) / factor
+ bbox_targets = bbox_targets.flatten(-2)
+ area_targets = kpt_pred.new_zeros(kpt_pred.shape[0])
+ area_targets[pos_inds] = gt_areas[pos_assigned_gt_inds.long()].float()
+
+ return (labels, label_weights, bbox_targets, kpt_targets, kpt_weights,
+ area_targets, pos_inds, neg_inds)
diff --git a/projects/petr/models/transformers.py b/projects/petr/models/transformers.py
new file mode 100644
index 0000000000..ff24316ff0
--- /dev/null
+++ b/projects/petr/models/transformers.py
@@ -0,0 +1,256 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Optional, Tuple, no_type_check
+
+import torch
+from mmcv.cnn import build_norm_layer
+from mmcv.cnn.bricks.transformer import FFN, MultiheadAttention
+from mmcv.ops import MultiScaleDeformableAttention
+from mmcv.ops.multi_scale_deform_attn import (
+ MultiScaleDeformableAttnFunction, multi_scale_deformable_attn_pytorch)
+from mmdet.models.layers import (DeformableDetrTransformerDecoder,
+ DeformableDetrTransformerDecoderLayer)
+from mmdet.models.layers.transformer.utils import inverse_sigmoid
+from mmengine.model import ModuleList
+from mmengine.utils import deprecated_api_warning
+from torch import Tensor, nn
+
+
+class MultiScaleDeformablePoseAttention(MultiScaleDeformableAttention):
+
+ @no_type_check
+ @deprecated_api_warning({'residual': 'identity'},
+ cls_name='MultiScaleDeformablePoseAttention')
+ def forward(self,
+ query: torch.Tensor,
+ key: Optional[torch.Tensor] = None,
+ value: Optional[torch.Tensor] = None,
+ identity: Optional[torch.Tensor] = None,
+ query_pos: Optional[torch.Tensor] = None,
+ key_padding_mask: Optional[torch.Tensor] = None,
+ reference_points: Optional[torch.Tensor] = None,
+ spatial_shapes: Optional[torch.Tensor] = None,
+ level_start_index: Optional[torch.Tensor] = None,
+ **kwargs) -> torch.Tensor:
+ """Forward Function of MultiScaleDeformAttention.
+
+ Args:
+ query (torch.Tensor): Query of Transformer with shape
+ (num_query, bs, embed_dims).
+ key (torch.Tensor): The key tensor with shape
+ `(num_key, bs, embed_dims)`.
+ value (torch.Tensor): The value tensor with shape
+ `(num_key, bs, embed_dims)`.
+ identity (torch.Tensor): The tensor used for addition, with the
+ same shape as `query`. Default None. If None,
+ `query` will be used.
+ query_pos (torch.Tensor): The positional encoding for `query`.
+ Default: None.
+ key_padding_mask (torch.Tensor): ByteTensor for `query`, with
+ shape [bs, num_key].
+ reference_points (torch.Tensor): The normalized reference
+ points with shape (bs, num_query, num_levels, 2),
+ all elements is range in [0, 1], top-left (0,0),
+ bottom-right (1, 1), including padding area.
+ or (N, Length_{query}, num_levels, 4), add
+ additional two dimensions is (w, h) to
+ form reference boxes.
+ spatial_shapes (torch.Tensor): Spatial shape of features in
+ different levels. With shape (num_levels, 2),
+ last dimension represents (h, w).
+ level_start_index (torch.Tensor): The start index of each level.
+ A tensor has shape ``(num_levels, )`` and can be represented
+ as [0, h_0*w_0, h_0*w_0+h_1*w_1, ...].
+
+ Returns:
+ torch.Tensor: forwarded results with shape
+ [num_query, bs, embed_dims].
+ """
+
+ if value is None:
+ value = query
+
+ if identity is None:
+ identity = query
+ if query_pos is not None:
+ query = query + query_pos
+ if not self.batch_first:
+ # change to (bs, num_query ,embed_dims)
+ query = query.permute(1, 0, 2)
+ value = value.permute(1, 0, 2)
+
+ bs, num_query, _ = query.shape
+ bs, num_value, _ = value.shape
+ assert (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() == num_value
+
+ value = self.value_proj(value)
+ if key_padding_mask is not None:
+ value = value.masked_fill(key_padding_mask[..., None], 0.0)
+ value = value.view(bs, num_value, self.num_heads, -1)
+ sampling_offsets = self.sampling_offsets(query).view(
+ bs, num_query, self.num_heads, self.num_levels, self.num_points, 2)
+ attention_weights = self.attention_weights(query).view(
+ bs, num_query, self.num_heads, self.num_levels * self.num_points)
+ attention_weights = attention_weights.softmax(-1)
+
+ attention_weights = attention_weights.view(bs, num_query,
+ self.num_heads,
+ self.num_levels,
+ self.num_points)
+ if reference_points.shape[-1] == self.num_points * 2:
+ reference_points_reshape = reference_points.reshape(
+ bs, num_query, self.num_levels, -1, 2).unsqueeze(2)
+ x1 = reference_points[:, :, :, 0::2].min(dim=-1, keepdim=True)[0]
+ y1 = reference_points[:, :, :, 1::2].min(dim=-1, keepdim=True)[0]
+ x2 = reference_points[:, :, :, 0::2].max(dim=-1, keepdim=True)[0]
+ y2 = reference_points[:, :, :, 1::2].max(dim=-1, keepdim=True)[0]
+ w = torch.clamp(x2 - x1, min=1e-4)
+ h = torch.clamp(y2 - y1, min=1e-4)
+ wh = torch.cat([w, h], dim=-1)[:, :, None, :, None, :]
+ sampling_locations = reference_points_reshape \
+ + sampling_offsets * wh * 0.5
+
+ else:
+ raise ValueError(
+ f'Last dim of reference_points must be {self.num_points*2}, '
+ f'but get {reference_points.shape[-1]} instead.')
+ if torch.cuda.is_available() and value.is_cuda:
+ output = MultiScaleDeformableAttnFunction.apply(
+ value, spatial_shapes, level_start_index, sampling_locations,
+ attention_weights, self.im2col_step)
+ else:
+ output = multi_scale_deformable_attn_pytorch(
+ value, spatial_shapes, sampling_locations, attention_weights)
+
+ output = self.output_proj(output)
+
+ if not self.batch_first:
+ # (num_query, bs ,embed_dims)
+ output = output.permute(1, 0, 2)
+
+ return self.dropout(output) + identity
+
+
+class PetrTransformerDecoderLayer(DeformableDetrTransformerDecoderLayer):
+ """Decoder layer of PETR."""
+
+ def _init_layers(self) -> None:
+ """Initialize self_attn, cross-attn, ffn, and norms."""
+ self.self_attn = MultiheadAttention(**self.self_attn_cfg)
+ self.cross_attn = MultiScaleDeformablePoseAttention(
+ **self.cross_attn_cfg)
+ self.embed_dims = self.self_attn.embed_dims
+ self.ffn = FFN(**self.ffn_cfg)
+ norms_list = [
+ build_norm_layer(self.norm_cfg, self.embed_dims)[1]
+ for _ in range(3)
+ ]
+ self.norms = ModuleList(norms_list)
+
+
+class PetrTransformerDecoder(DeformableDetrTransformerDecoder):
+ """Transformer Decoder of PETR."""
+
+ def __init__(self, num_keypoints: int, *args, **kwargs):
+ self.num_keypoints = num_keypoints
+ super().__init__(*args, **kwargs)
+
+ def _init_layers(self) -> None:
+ """Initialize decoder layers."""
+ self.layers = ModuleList([
+ PetrTransformerDecoderLayer(**self.layer_cfg)
+ for _ in range(self.num_layers)
+ ])
+ self.embed_dims = self.layers[0].embed_dims
+ if self.post_norm_cfg is not None:
+ raise ValueError('There is not post_norm in '
+ f'{self._get_name()}')
+
+ def forward(self,
+ query: Tensor,
+ query_pos: Tensor,
+ value: Tensor,
+ key_padding_mask: Tensor,
+ reference_points: Tensor,
+ spatial_shapes: Tensor,
+ level_start_index: Tensor,
+ valid_ratios: Tensor,
+ reg_branches: Optional[nn.Module] = None,
+ **kwargs) -> Tuple[Tensor]:
+ """Forward function of Transformer decoder.
+
+ Args:
+ query (Tensor): The input queries, has shape (bs, num_queries,
+ dim).
+ query_pos (Tensor): The input positional query, has shape
+ (bs, num_queries, dim). It will be added to `query` before
+ forward function.
+ value (Tensor): The input values, has shape (bs, num_value, dim).
+ key_padding_mask (Tensor): The `key_padding_mask` of `cross_attn`
+ input. ByteTensor, has shape (bs, num_value).
+ reference_points (Tensor): The initial reference, has shape
+ (bs, num_queries, 4) with the last dimension arranged as
+ (cx, cy, w, h) when `as_two_stage` is `True`, otherwise has
+ shape (bs, num_queries, 2) with the last dimension arranged
+ as (cx, cy).
+ spatial_shapes (Tensor): Spatial shapes of features in all levels,
+ has shape (num_levels, 2), last dimension represents (h, w).
+ level_start_index (Tensor): The start index of each level.
+ A tensor has shape (num_levels, ) and can be represented
+ as [0, h_0*w_0, h_0*w_0+h_1*w_1, ...].
+ valid_ratios (Tensor): The ratios of the valid width and the valid
+ height relative to the width and the height of features in all
+ levels, has shape (bs, num_levels, 2).
+ reg_branches: (obj:`nn.ModuleList`, optional): Used for refining
+ the regression results. Only would be passed when
+ `with_box_refine` is `True`, otherwise would be `None`.
+
+ Returns:
+ tuple[Tensor]: Outputs of Deformable Transformer Decoder.
+
+ - output (Tensor): Output embeddings of the last decoder, has
+ shape (num_queries, bs, embed_dims) when `return_intermediate`
+ is `False`. Otherwise, Intermediate output embeddings of all
+ decoder layers, has shape (num_decoder_layers, num_queries, bs,
+ embed_dims).
+ - reference_points (Tensor): The reference of the last decoder
+ layer, has shape (bs, num_queries, 4) when `return_intermediate`
+ is `False`. Otherwise, Intermediate references of all decoder
+ layers, has shape (num_decoder_layers, bs, num_queries, 4). The
+ coordinates are arranged as (cx, cy, w, h)
+ """
+ output = query
+ intermediate = []
+ intermediate_reference_points = []
+ for layer_id, layer in enumerate(self.layers):
+ assert reference_points.shape[-1] == self.num_keypoints * 2
+ reference_points_input = \
+ reference_points[:, :, None] * \
+ valid_ratios.repeat(1, 1, self.num_keypoints)[:, None]
+
+ output = layer(
+ output,
+ query_pos=query_pos,
+ value=value,
+ key_padding_mask=key_padding_mask,
+ spatial_shapes=spatial_shapes,
+ level_start_index=level_start_index,
+ valid_ratios=valid_ratios,
+ reference_points=reference_points_input,
+ **kwargs)
+
+ if reg_branches is not None:
+ tmp_reg_preds = reg_branches[layer_id](output)
+ new_reference_points = tmp_reg_preds + inverse_sigmoid(
+ reference_points)
+ new_reference_points = new_reference_points.sigmoid()
+ reference_points = new_reference_points.detach()
+
+ if self.return_intermediate:
+ intermediate.append(output)
+ intermediate_reference_points.append(reference_points)
+
+ if self.return_intermediate:
+ return torch.stack(intermediate), torch.stack(
+ intermediate_reference_points)
+
+ return output, reference_points
diff --git a/projects/petr/tools b/projects/petr/tools
new file mode 120000
index 0000000000..31941e941d
--- /dev/null
+++ b/projects/petr/tools
@@ -0,0 +1 @@
+../../tools
\ No newline at end of file