-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
47 changed files
with
4,267 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
# Official Pytorch Implementation of DToP | ||
|
||
### Dynamic Token Pruning in Plain Vision Transformers for Semantic Segmentation | ||
Quan Tang, Bowen Zhang, Jiajun Liu, Fagiu Liu, Yifan Liu | ||
|
||
ICCV 2023. [[arxiv]](https://arxiv.org/abs/2308.01045) | ||
|
||
This repository contains the official Pytorch implementation of training & evaluation code and the pretrained models for DToP | ||
|
||
As shown in the following figure, the network is naturally split into stages using inherent auxiliary blocks. | ||
|
||
<img src="./resources/fig-1-1.png"> | ||
|
||
## Highlights | ||
* **Dynamic Token Pruning** We introduce a dynamic token pruning paradigm based on the early exit of easy-to-recognize tokens for semantic segmentation transformers. | ||
* **Controllable prune ratio** One hyperparameter to control the trade-off between computation cost and accuracy. | ||
* **Generally applicable** e apply DToP to mainstream semantic segmentation transformers and can reduce up to 35% computational cost without a notable accuracy drop. | ||
|
||
## Getting started | ||
1. requirements | ||
``` | ||
torch==2.0.0 mmcls==1.0.0.rc5, mmcv==2.0.0 mmengine==0.7.0 mmsegmentation==1.0.0rc6 | ||
``` | ||
or up-to-date mmxx series till 9 Aug 2023 | ||
|
||
## Training | ||
To aquire the base model | ||
``` | ||
python tools dist_train.sh config/prune/BASE_segvit_ade20k_large.py $work_dirs$ | ||
``` | ||
To prune on the base model | ||
``` | ||
python tools dist_train_load.sh config/prune/prune_segvit_ade20k_large.py $work_dirs$ $path_to_ckpt$ | ||
``` | ||
|
||
## Eval | ||
``` | ||
python tools/dist_test.sh config/prune/prune_segvit_ade20k_large.py $path_to_ckpt$ | ||
``` | ||
|
||
## Datasets | ||
Please follow the instructions of [mmsegmentation](https://github.com/open-mmlab/mmsegmentation) data preparation | ||
|
||
## Results | ||
### Ade20k | ||
| Method | Backbone | mIoU | GFlops | config | ckpt | | ||
|--------------|-----------|------|--------|--------|------| | ||
| Segvit | Vit-base | 49.6 | 109.9 | [config](./config/prune/BASE_segvit_ade20k.py) | | | ||
| Segvit-prune | Vit-base | 49.8 | 86.8 | [config](./config/prune/prune_segvit_ade20k.py) | | | ||
| Segvit | Vit-large | 53.3 | 617.0 | [config](./config/prune/BASE_segvit_ade20k_large.py) | | | ||
| Segvit-prune | Vit-large | 52.8 | 412.8 | [config](./config/prune/prune_segvit_ade20k_large.py) | | | ||
|
||
### Pascal Context | ||
| Method | Backbone | mIoU | GFlops | config | ckpt | | ||
|--------------|-----------|------|--------|--------|------| | ||
| Segvit | Vit-large | 63.0 | 315.4 | [config](./config/prune/BASE_segvit_pc.py) | | | ||
| Segvit-prune | Vit-large | 62.7 | 224.3 | [config](./config/prune/prune_segvit_pc.py) | | | ||
|
||
### COCO-Stuff-10K | ||
| Method | Backbone | mIoU | GFlops | config | ckpt | | ||
|--------------|-----------|------|--------|--------|------| | ||
| Segvit | Vit-large | 47.4 | 366.9 | [config](./config/prune/BASE_segvit_cocostuff10k.py) | | | ||
| Segvit-prune | Vit-large | 47.1 | 276.2 | [config](./config/prune/prune_segvit_cocostuff10k.py) | | | ||
|
||
|
||
|
||
## License | ||
For academic use, this project is licensed under the 2-clause BSD License - see the LICENSE file for details. For commercial use, please contact the authors. | ||
|
||
## Citation |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
# dataset settings | ||
dataset_type = 'ADE20KDataset' | ||
data_root = '/data/ADEChallengeData2016' | ||
crop_size = (512, 512) | ||
train_pipeline = [ | ||
dict(type='LoadImageFromFile'), | ||
dict(type='LoadAnnotations', reduce_zero_label=True), | ||
dict( | ||
type='RandomResize', | ||
scale=(2048, 512), | ||
ratio_range=(0.5, 2.0), | ||
keep_ratio=True), | ||
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), | ||
dict(type='RandomFlip', prob=0.5), | ||
dict(type='PhotoMetricDistortion'), | ||
dict(type='PackSegInputs') | ||
] | ||
test_pipeline = [ | ||
dict(type='LoadImageFromFile'), | ||
dict(type='Resize', scale=(2048, 512), keep_ratio=True), | ||
# add loading annotation after ``Resize`` because ground truth | ||
# does not need to do resize data transform | ||
dict(type='LoadAnnotations', reduce_zero_label=True), | ||
dict(type='PackSegInputs') | ||
] | ||
img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75] | ||
tta_pipeline = [ | ||
dict(type='LoadImageFromFile', backend_args=None), | ||
dict( | ||
type='TestTimeAug', | ||
transforms=[ | ||
[ | ||
dict(type='Resize', scale_factor=r, keep_ratio=True) | ||
for r in img_ratios | ||
], | ||
[ | ||
dict(type='RandomFlip', prob=0., direction='horizontal'), | ||
dict(type='RandomFlip', prob=1., direction='horizontal') | ||
], [dict(type='LoadAnnotations')], [dict(type='PackSegInputs')] | ||
]) | ||
] | ||
train_dataloader = dict( | ||
batch_size=4, | ||
num_workers=4, | ||
persistent_workers=True, | ||
sampler=dict(type='InfiniteSampler', shuffle=True), | ||
dataset=dict( | ||
type=dataset_type, | ||
data_root=data_root, | ||
data_prefix=dict( | ||
img_path='images/training', seg_map_path='annotations/training'), | ||
pipeline=train_pipeline)) | ||
val_dataloader = dict( | ||
batch_size=1, | ||
num_workers=4, | ||
persistent_workers=True, | ||
sampler=dict(type='DefaultSampler', shuffle=False), | ||
dataset=dict( | ||
type=dataset_type, | ||
data_root=data_root, | ||
data_prefix=dict( | ||
img_path='images/validation', | ||
seg_map_path='annotations/validation'), | ||
pipeline=test_pipeline)) | ||
test_dataloader = val_dataloader | ||
|
||
val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU']) | ||
test_evaluator = val_evaluator |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
# dataset settings | ||
dataset_type = 'ADE20KDataset' | ||
data_root = '/data/ADEChallengeData2016' | ||
crop_size = (640, 640) | ||
train_pipeline = [ | ||
dict(type='LoadImageFromFile'), | ||
dict(type='LoadAnnotations', reduce_zero_label=True), | ||
dict( | ||
type='RandomResize', | ||
scale=(2560, 640), | ||
ratio_range=(0.5, 2.0), | ||
keep_ratio=True), | ||
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), | ||
dict(type='RandomFlip', prob=0.5), | ||
dict(type='PhotoMetricDistortion'), | ||
dict(type='PackSegInputs') | ||
] | ||
test_pipeline = [ | ||
dict(type='LoadImageFromFile'), | ||
dict(type='Resize', scale=(2560, 640), keep_ratio=True), | ||
# add loading annotation after ``Resize`` because ground truth | ||
# does not need to do resize data transform | ||
dict(type='LoadAnnotations', reduce_zero_label=True), | ||
dict(type='PackSegInputs') | ||
] | ||
img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75] | ||
tta_pipeline = [ | ||
dict(type='LoadImageFromFile', backend_args=None), | ||
dict( | ||
type='TestTimeAug', | ||
transforms=[ | ||
[ | ||
dict(type='Resize', scale_factor=r, keep_ratio=True) | ||
for r in img_ratios | ||
], | ||
[ | ||
dict(type='RandomFlip', prob=0., direction='horizontal'), | ||
dict(type='RandomFlip', prob=1., direction='horizontal') | ||
], [dict(type='LoadAnnotations')], [dict(type='PackSegInputs')] | ||
]) | ||
] | ||
train_dataloader = dict( | ||
batch_size=4, | ||
num_workers=4, | ||
persistent_workers=True, | ||
sampler=dict(type='InfiniteSampler', shuffle=True), | ||
dataset=dict( | ||
type=dataset_type, | ||
data_root=data_root, | ||
data_prefix=dict( | ||
img_path='images/training', seg_map_path='annotations/training'), | ||
pipeline=train_pipeline)) | ||
val_dataloader = dict( | ||
batch_size=1, | ||
num_workers=4, | ||
persistent_workers=True, | ||
sampler=dict(type='DefaultSampler', shuffle=False), | ||
dataset=dict( | ||
type=dataset_type, | ||
data_root=data_root, | ||
data_prefix=dict( | ||
img_path='images/validation', | ||
seg_map_path='annotations/validation'), | ||
pipeline=test_pipeline)) | ||
test_dataloader = val_dataloader | ||
|
||
val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU']) | ||
test_evaluator = val_evaluator |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
# dataset settings | ||
dataset_type = 'COCOStuffDataset' | ||
data_root = '/data/coco_stuff10k' | ||
crop_size = (512, 512) | ||
train_pipeline = [ | ||
dict(type='LoadImageFromFile'), | ||
dict(type='LoadAnnotations', reduce_zero_label=True), | ||
dict( | ||
type='RandomResize', | ||
scale=(2048, 512), | ||
ratio_range=(0.5, 2.0), | ||
keep_ratio=True), | ||
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), | ||
dict(type='RandomFlip', prob=0.5), | ||
dict(type='PhotoMetricDistortion'), | ||
dict(type='PackSegInputs') | ||
] | ||
test_pipeline = [ | ||
dict(type='LoadImageFromFile'), | ||
dict(type='Resize', scale=(2048, 512), keep_ratio=True), | ||
# add loading annotation after ``Resize`` because ground truth | ||
# does not need to do resize data transform | ||
dict(type='LoadAnnotations', reduce_zero_label=True), | ||
dict(type='PackSegInputs') | ||
] | ||
img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75] | ||
tta_pipeline = [ | ||
dict(type='LoadImageFromFile', backend_args=None), | ||
dict( | ||
type='TestTimeAug', | ||
transforms=[ | ||
[ | ||
dict(type='Resize', scale_factor=r, keep_ratio=True) | ||
for r in img_ratios | ||
], | ||
[ | ||
dict(type='RandomFlip', prob=0., direction='horizontal'), | ||
dict(type='RandomFlip', prob=1., direction='horizontal') | ||
], [dict(type='LoadAnnotations')], [dict(type='PackSegInputs')] | ||
]) | ||
] | ||
train_dataloader = dict( | ||
batch_size=4, | ||
num_workers=4, | ||
persistent_workers=True, | ||
sampler=dict(type='InfiniteSampler', shuffle=True), | ||
dataset=dict( | ||
type=dataset_type, | ||
data_root=data_root, | ||
reduce_zero_label=True, | ||
data_prefix=dict( | ||
img_path='images/train2014', seg_map_path='annotations/train2014'), | ||
pipeline=train_pipeline)) | ||
val_dataloader = dict( | ||
batch_size=1, | ||
num_workers=4, | ||
persistent_workers=True, | ||
sampler=dict(type='DefaultSampler', shuffle=False), | ||
dataset=dict( | ||
type=dataset_type, | ||
data_root=data_root, | ||
reduce_zero_label=True, | ||
data_prefix=dict( | ||
img_path='images/test2014', seg_map_path='annotations/test2014'), | ||
pipeline=test_pipeline)) | ||
test_dataloader = val_dataloader | ||
|
||
val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU']) | ||
test_evaluator = val_evaluator |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
# dataset settings | ||
dataset_type = 'PascalContextDataset59' | ||
data_root = '/data/VOCdevkit/VOC2010/' | ||
|
||
img_scale = (2048, 520) # compromise for mmengin Resize | ||
crop_size = (480, 480) | ||
|
||
train_pipeline = [ | ||
dict(type='LoadImageFromFile'), | ||
dict(type='LoadAnnotations', reduce_zero_label=True), | ||
dict( | ||
type='RandomResize', | ||
scale=img_scale, | ||
ratio_range=(0.5, 2.0), | ||
keep_ratio=True), | ||
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), | ||
dict(type='RandomFlip', prob=0.5), | ||
dict(type='PhotoMetricDistortion'), | ||
dict(type='PackSegInputs') | ||
] | ||
test_pipeline = [ | ||
dict(type='LoadImageFromFile'), | ||
dict(type='Resize', scale=img_scale, keep_ratio=True), | ||
# add loading annotation after ``Resize`` because ground truth | ||
# does not need to do resize data transform | ||
dict(type='LoadAnnotations', reduce_zero_label=True), | ||
dict(type='PackSegInputs') | ||
] | ||
img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75] | ||
tta_pipeline = [ | ||
dict(type='LoadImageFromFile', backend_args=None), | ||
dict( | ||
type='TestTimeAug', | ||
transforms=[ | ||
[ | ||
dict(type='Resize', scale_factor=r, keep_ratio=True) | ||
for r in img_ratios | ||
], | ||
[ | ||
dict(type='RandomFlip', prob=0., direction='horizontal'), | ||
dict(type='RandomFlip', prob=1., direction='horizontal') | ||
], [dict(type='LoadAnnotations')], [dict(type='PackSegInputs')] | ||
]) | ||
] | ||
train_dataloader = dict( | ||
batch_size=4, | ||
num_workers=4, | ||
persistent_workers=True, | ||
sampler=dict(type='InfiniteSampler', shuffle=True), | ||
dataset=dict( | ||
type=dataset_type, | ||
data_root=data_root, | ||
data_prefix=dict( | ||
img_path='JPEGImages', seg_map_path='SegmentationClassContext'), | ||
ann_file='ImageSets/SegmentationContext/train.txt', | ||
pipeline=train_pipeline)) | ||
val_dataloader = dict( | ||
batch_size=1, | ||
num_workers=4, | ||
persistent_workers=True, | ||
sampler=dict(type='DefaultSampler', shuffle=False), | ||
dataset=dict( | ||
type=dataset_type, | ||
data_root=data_root, | ||
data_prefix=dict( | ||
img_path='JPEGImages', seg_map_path='SegmentationClassContext'), | ||
ann_file='ImageSets/SegmentationContext/val.txt', | ||
pipeline=test_pipeline)) | ||
test_dataloader = val_dataloader | ||
|
||
val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU']) | ||
test_evaluator = val_evaluator |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
default_scope = 'mmseg' | ||
env_cfg = dict( | ||
cudnn_benchmark=True, | ||
mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0), | ||
dist_cfg=dict(backend='nccl'), | ||
) | ||
vis_backends = [dict(type='LocalVisBackend')] | ||
visualizer = dict( | ||
type='SegLocalVisualizer', vis_backends=vis_backends, name='visualizer') | ||
log_processor = dict(by_epoch=False) | ||
log_level = 'INFO' | ||
load_from = None | ||
resume = False | ||
|
||
tta_model = dict(type='SegTTAModel') |
Oops, something went wrong.