-
Notifications
You must be signed in to change notification settings - Fork 1.3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
20 changed files
with
2,856 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
# PETR | ||
|
||
This project implements PETR (Pose Estimation with TRansformers), an end-to-end multi-person pose estimation framework introduced in the CVPR 2022 paper **End-to-End Multi-Person Pose Estimation with Transformers**. PETR is a novel, end-to-end multi-person pose estimation method that treats pose estimation as a hierarchical set prediction problem. By leveraging attention mechanisms, PETR can adaptively focus on features most relevant to target keypoints, thereby overcoming feature misalignment issues in pose estimation. | ||
|
||
<img src="https://github.com/open-mmlab/mmpose/assets/26127467/ec7eb99d-8b8b-4c0d-9714-0ccd33a4f054" alt><br> | ||
|
||
## Usage | ||
|
||
### Prerequisites | ||
|
||
- Python 3.7 or higher | ||
- PyTorch 1.6 or higher | ||
- [MMEngine](https://github.com/open-mmlab/mmengine) v0.7.0 or higher | ||
- [MMCV](https://github.com/open-mmlab/mmcv) v2.0.0 or higher | ||
- [MMDetection](https://github.com/open-mmlab/mmdetection) v3.0.0 or higher | ||
- [MMPose](https://github.com/open-mmlab/mmpose) v1.0.0 or higher | ||
|
||
All the commands below rely on the correct configuration of `PYTHONPATH`, which should point to the project's directory so that Python can locate the module files. **In `petr/` root directory**, run the following line to add the current directory to `PYTHONPATH`: | ||
|
||
```shell | ||
export PYTHONPATH=`pwd`:$PYTHONPATH | ||
``` | ||
|
||
### Inference | ||
|
||
Users can apply PETR models to estimate human poses using the inferencer found in the MMPose core package. Use the command below: | ||
|
||
```shell | ||
python demo/inferencer_demo.py $INPUTS \ | ||
--pose2d $CONFIG --pose2d-weights $CHECKPOINT --scope mmdet \ | ||
[--show] [--vis-out-dir $VIS_OUT_DIR] [--pred-out-dir $PRED_OUT_DIR] | ||
``` | ||
|
||
For more information on using the inferencer, please see [this document](https://mmpose.readthedocs.io/en/latest/user_guides/inference.html#out-of-the-box-inferencer). | ||
|
||
### Results | ||
|
||
Results on COCO val2017 | ||
|
||
| Model | mAP | AP<sup>50</sup> | AP<sup>75</sup> | AR | AR<sup>50</sup> | Checkpoint | Log | | ||
| :----------------------------------------------: | :--: | :-------------: | :-------------: | :--: | :-------------: | :---------------------------------------------------: | :--------------------------------------------: | | ||
| [PETR-R50](/configs/petr_r50_8xb3-100e_coco.py) | 68.7 | 86.2 | 76.0 | 76.5 | 91.5 | [ckpt](https://download.openmmlab.com/mmpose/v1/projects/petr/petr_r50_8xb3-100e_coco-520803d9_20230613.pth) | [log](https://download.openmmlab.com/mmpose/v1/projects/petr/petr_r50_8xb3-100e_coco-20230613.json) | | ||
| [PETR-R50](/configs/petr_r50_8xb4-100e_coco.py)\* | 68.7 | 87.5 | 76.2 | 75.9 | 92.1 | [Google Drive](https://drive.google.com/file/d/1HcwraqWdZ3CaGMQOJHY8exNem7UnFkfS/view?usp=sharing) \| [BaiduYun](https://pan.baidu.com/s/1C0HbQWV7K-GHQE7q34nUZw?pwd=u798) | / | | ||
| [PETR-R101](/configs/petr_r101_8xb4-100e_coco.py)\* | 70.0 | 88.5 | 77.5 | 77.0 | 92.6 | [Google Drive](https://drive.google.com/file/d/1O261Jrt4JRGlIKTmLtPy3AUruwX1hsDf/view?usp=sharing) \| [BaiduYun](https://pan.baidu.com/s/1D5wqNP53KNOKKE5NnO2Dnw?pwd=keyn) | / | | ||
| [PETR-Swin](/configs/petr_swin-l_8xb2-100e_coco.py)\* | 73.0 | 90.7 | 80.9 | 80.1 | 94.5 | [Google Drive](https://drive.google.com/file/d/1ujL0Gm5tPjweT0-gdDGkTc7xXrEt6gBP/view?usp=sharing) \| [BaiduYun](https://pan.baidu.com/s/1X5Cdq75GosRCKqbHZTSpJQ?pwd=t9ea) | / | | ||
|
||
\* Indicates that the checkpoints are sourced from the official repository. The training accuracy is still not up to the results reported in the paper. We will continue to update this project after aligning the training accuracy. | ||
|
||
## Citation | ||
|
||
If this project benefits your work, please kindly consider citing the original papers: | ||
|
||
```bibtex | ||
@inproceedings{shi2022end, | ||
title={End-to-end multi-person pose estimation with transformers}, | ||
author={Shi, Dahu and Wei, Xing and Li, Liangqi and Ren, Ye and Tan, Wenming}, | ||
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, | ||
pages={11069--11078}, | ||
year={2022} | ||
} | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
../../../../configs/_base_/datasets |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
default_scope = 'mmdet' | ||
custom_imports = dict(imports=['models', 'datasets']) | ||
|
||
# hooks | ||
default_hooks = dict( | ||
timer=dict(type='IterTimerHook'), | ||
logger=dict(type='LoggerHook', interval=50), | ||
param_scheduler=dict(type='ParamSchedulerHook'), | ||
checkpoint=dict(type='CheckpointHook', interval=5, max_keep_ckpts=3), | ||
sampler_seed=dict(type='DistSamplerSeedHook'), | ||
visualization=dict(type='mmpose.PoseVisualizationHook', enable=False), | ||
) | ||
|
||
# multi-processing backend | ||
env_cfg = dict( | ||
cudnn_benchmark=False, | ||
mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0), | ||
dist_cfg=dict(backend='nccl'), | ||
) | ||
|
||
# visualizer | ||
vis_backends = [dict(type='LocalVisBackend')] | ||
visualizer = dict( | ||
type='mmpose.PoseLocalVisualizer', | ||
vis_backends=vis_backends, | ||
name='visualizer') | ||
|
||
# logger | ||
log_processor = dict( | ||
type='LogProcessor', window_size=50, by_epoch=True, num_digits=6) | ||
log_level = 'INFO' | ||
load_from = None | ||
resume = False | ||
|
||
# file I/O backend | ||
backend_args = dict(backend='local') | ||
|
||
# training/validation/testing progress | ||
train_cfg = dict() | ||
val_cfg = dict(type='ValLoop') | ||
test_cfg = dict(type='TestLoop') |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
_base_ = ['petr_r50_8xb4-100e_coco.py'] | ||
|
||
# model | ||
checkpoint = 'https://download.openmmlab.com/mmpose/mmdet_pretrained/' \ | ||
'deformable_detr_twostage_refine_r101_16x2_50e_coco-3186d66b_20230613.pth' | ||
|
||
model = dict(init_cfg=dict(checkpoint=checkpoint), backbone=dict(depth=101)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
_base_ = ['petr_r50_8xb4-100e_coco.py'] | ||
|
||
auto_scale_lr = dict(base_batch_size=24) | ||
train_dataloader = dict(batch_size=3) | ||
optim_wrapper = dict(optimizer=dict(lr=0.00015)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,263 @@ | ||
_base_ = ['./_base_/default_runtime.py'] | ||
|
||
# learning policy | ||
max_epochs = 100 | ||
train_cfg = dict( | ||
type='EpochBasedTrainLoop', max_epochs=max_epochs, val_interval=5) | ||
val_cfg = dict(type='ValLoop') | ||
test_cfg = dict(type='TestLoop') | ||
|
||
param_scheduler = [ | ||
dict( | ||
type='MultiStepLR', | ||
begin=0, | ||
end=max_epochs, | ||
by_epoch=True, | ||
milestones=[80], | ||
gamma=0.1) | ||
] | ||
|
||
# NOTE: `auto_scale_lr` is for automatically scaling LR, | ||
auto_scale_lr = dict(base_batch_size=32) | ||
|
||
# optimizer | ||
optim_wrapper = dict( | ||
type='OptimWrapper', | ||
optimizer=dict(type='AdamW', lr=0.0002, weight_decay=0.0001), | ||
clip_grad=dict(max_norm=0.1, norm_type=2), | ||
paramwise_cfg=dict( | ||
custom_keys={ | ||
'backbone': dict(lr_mult=0.1), | ||
'sampling_offsets': dict(lr_mult=0.1), | ||
'reference_points': dict(lr_mult=0.1) | ||
})) | ||
|
||
# model | ||
num_keypoints = 17 | ||
checkpoint = 'https://download.openmmlab.com/mmdetection/v2.0/deformable_' \ | ||
'detr/deformable_detr_twostage_refine_r50_16x2_50e_coco/deformable_' \ | ||
'detr_twostage_refine_r50_16x2_50e_coco_20210419_220613-9d28ab72.pth' | ||
model = dict( | ||
type='PETR', | ||
num_queries=300, | ||
num_feature_levels=4, | ||
num_keypoints=num_keypoints, | ||
with_box_refine=True, | ||
as_two_stage=True, | ||
init_cfg=dict( | ||
type='Pretrained', | ||
checkpoint=checkpoint, | ||
), | ||
data_preprocessor=dict( | ||
type='DetDataPreprocessor', | ||
mean=[123.675, 116.28, 103.53], | ||
std=[58.395, 57.12, 57.375], | ||
bgr_to_rgb=True, | ||
pad_size_divisor=1), | ||
backbone=dict( | ||
type='ResNet', | ||
depth=50, | ||
num_stages=4, | ||
out_indices=(1, 2, 3), | ||
frozen_stages=1, | ||
norm_cfg=dict(type='BN', requires_grad=False), | ||
norm_eval=True, | ||
style='pytorch'), | ||
neck=dict( | ||
type='ChannelMapper', | ||
in_channels=[512, 1024, 2048], | ||
kernel_size=1, | ||
out_channels=256, | ||
act_cfg=None, | ||
norm_cfg=dict(type='GN', num_groups=32), | ||
num_outs=4), | ||
encoder=dict( # DeformableDetrTransformerEncoder | ||
num_layers=6, | ||
layer_cfg=dict( # DeformableDetrTransformerEncoderLayer | ||
self_attn_cfg=dict( # MultiScaleDeformableAttention | ||
embed_dims=256, | ||
batch_first=True), | ||
ffn_cfg=dict( | ||
embed_dims=256, feedforward_channels=1024, ffn_drop=0.1))), | ||
decoder=dict( # PetrTransformerDecoder | ||
num_layers=3, | ||
num_keypoints=num_keypoints, | ||
return_intermediate=True, | ||
layer_cfg=dict( # PetrTransformerDecoderLayer | ||
self_attn_cfg=dict( # MultiheadAttention | ||
embed_dims=256, | ||
num_heads=8, | ||
dropout=0.1, | ||
batch_first=True), | ||
cross_attn_cfg=dict( # MultiScaleDeformablePoseAttention | ||
embed_dims=256, | ||
num_points=num_keypoints, | ||
batch_first=True), | ||
ffn_cfg=dict( | ||
embed_dims=256, feedforward_channels=1024, ffn_drop=0.1)), | ||
post_norm_cfg=None), | ||
hm_encoder=dict( # DeformableDetrTransformerEncoder | ||
num_layers=1, | ||
layer_cfg=dict( # DeformableDetrTransformerEncoderLayer | ||
self_attn_cfg=dict( # MultiScaleDeformableAttention | ||
embed_dims=256, | ||
num_levels=1, | ||
batch_first=True), | ||
ffn_cfg=dict( | ||
embed_dims=256, feedforward_channels=1024, ffn_drop=0.1))), | ||
kpt_decoder=dict( # DeformableDetrTransformerDecoder | ||
num_layers=2, | ||
return_intermediate=True, | ||
layer_cfg=dict( # DetrTransformerDecoderLayer | ||
self_attn_cfg=dict( # MultiheadAttention | ||
embed_dims=256, | ||
num_heads=8, | ||
dropout=0.1, | ||
batch_first=True), | ||
cross_attn_cfg=dict( # MultiScaleDeformableAttention | ||
embed_dims=256, | ||
im2col_step=128), | ||
ffn_cfg=dict( | ||
embed_dims=256, feedforward_channels=1024, ffn_drop=0.1)), | ||
post_norm_cfg=None), | ||
positional_encoding=dict(num_feats=128, normalize=True, offset=-0.5), | ||
bbox_head=dict( | ||
type='PETRHead', | ||
num_classes=1, | ||
num_keypoints=num_keypoints, | ||
sync_cls_avg_factor=True, | ||
loss_cls=dict( | ||
type='FocalLoss', | ||
use_sigmoid=True, | ||
gamma=2.0, | ||
alpha=0.25, | ||
loss_weight=2.0), | ||
loss_reg=dict(type='L1Loss', loss_weight=40.0), | ||
loss_reg_aux=dict(type='L1Loss', loss_weight=35.0), | ||
loss_oks=dict( | ||
type='OksLoss', | ||
metainfo='configs/_base_/datasets/coco.py', | ||
loss_weight=3.0), | ||
loss_oks_aux=dict( | ||
type='OksLoss', | ||
metainfo='configs/_base_/datasets/coco.py', | ||
loss_weight=2.0), | ||
loss_hm=dict(type='mmpose.FocalHeatmapLoss', loss_weight=4.0), | ||
), | ||
# training and testing settings | ||
train_cfg=dict( | ||
assigner=dict( | ||
type='HungarianAssigner', | ||
match_costs=[ | ||
dict(type='FocalLossCost', weight=2.0), | ||
dict(type='KptL1Cost', weight=70.0), | ||
dict( | ||
type='OksCost', | ||
metainfo='configs/_base_/datasets/coco.py', | ||
weight=7.0) | ||
])), | ||
test_cfg=dict( | ||
max_per_img=100, | ||
score_thr=0.0, | ||
)) | ||
|
||
train_pipeline = [ | ||
dict(type='mmpose.LoadImage', backend_args=_base_.backend_args), | ||
dict(type='PoseToDetConverter'), | ||
dict(type='PhotoMetricDistortion'), | ||
dict( | ||
type='RandomAffine', | ||
max_rotate_degree=30.0, | ||
max_translate_ratio=0., | ||
scaling_ratio_range=(1., 1.), | ||
max_shear_degree=0., | ||
border_val=[103.53, 116.28, 123.675], | ||
), | ||
dict(type='RandomFlip', prob=0.5), | ||
dict( | ||
type='RandomChoice', | ||
transforms=[ | ||
[ | ||
dict( | ||
type='RandomChoiceResize', | ||
scales=list(zip(range(400, 1401, 8), (1400, ) * 126)), | ||
keep_ratio=True) | ||
], | ||
[ | ||
dict( | ||
type='RandomChoiceResize', | ||
# The radio of all image in train dataset < 7 | ||
# follow the original implement | ||
scales=[(400, 4200), (500, 4200), (600, 4200)], | ||
keep_ratio=True), | ||
dict( | ||
type='RandomCrop', | ||
crop_type='absolute_range', | ||
crop_size=(384, 600), | ||
allow_negative_crop=True), | ||
dict( | ||
type='RandomChoiceResize', | ||
scales=list(zip(range(400, 1401, 8), (1400, ) * 126)), | ||
keep_ratio=True) | ||
] | ||
]), | ||
dict(type='FilterDetPoseAnnotations', keep_empty=False), | ||
dict(type='GenerateHeatmap'), | ||
dict( | ||
type='PackDetPoseInputs', | ||
meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape')) | ||
] | ||
|
||
test_pipeline = [ | ||
dict(type='mmpose.LoadImage', backend_args=_base_.backend_args), | ||
dict(type='PoseToDetConverter'), | ||
dict(type='Resize', scale=(1333, 800), keep_ratio=True), | ||
dict( | ||
type='PackDetPoseInputs', | ||
meta_keys=('id', 'img_id', 'img_path', 'ori_shape', 'img_shape', | ||
'scale_factor', 'flip_indices')) | ||
] | ||
|
||
dataset_type = 'CocoDataset' | ||
data_mode = 'bottomup' | ||
data_root = 'data/coco/' | ||
|
||
train_dataloader = dict( | ||
batch_size=4, | ||
num_workers=4, | ||
persistent_workers=True, | ||
sampler=dict(type='DefaultSampler', shuffle=True), | ||
batch_sampler=dict(type='AspectRatioBatchSampler'), | ||
dataset=dict( | ||
type=dataset_type, | ||
data_mode=data_mode, | ||
data_root=data_root, | ||
ann_file='annotations/person_keypoints_train2017.json', | ||
data_prefix=dict(img='train2017/'), | ||
filter_cfg=dict(filter_empty_gt=True, min_size=32), | ||
pipeline=train_pipeline)) | ||
|
||
val_dataloader = dict( | ||
batch_size=1, | ||
num_workers=2, | ||
persistent_workers=True, | ||
drop_last=False, | ||
sampler=dict(type='DefaultSampler', shuffle=False), | ||
dataset=dict( | ||
type=dataset_type, | ||
data_mode=data_mode, | ||
data_root=data_root, | ||
ann_file='annotations/person_keypoints_val2017.json', | ||
data_prefix=dict(img='val2017/'), | ||
test_mode=True, | ||
pipeline=test_pipeline)) | ||
test_dataloader = val_dataloader | ||
|
||
val_evaluator = dict( | ||
type='mmpose.CocoMetric', | ||
ann_file=data_root + 'annotations/person_keypoints_val2017.json', | ||
nms_mode='none', | ||
score_mode='bbox') | ||
test_evaluator = val_evaluator | ||
|
||
default_hooks = dict(checkpoint=dict(save_best='coco/AP', rule='greater')) |
Oops, something went wrong.