Skip to content

Commit

Permalink
init
Browse files Browse the repository at this point in the history
  • Loading branch information
zbwxp authored Aug 9, 2023
1 parent 97c19df commit aeecb34
Show file tree
Hide file tree
Showing 47 changed files with 4,267 additions and 0 deletions.
70 changes: 70 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
# Official Pytorch Implementation of DToP

### Dynamic Token Pruning in Plain Vision Transformers for Semantic Segmentation
Quan Tang, Bowen Zhang, Jiajun Liu, Fagiu Liu, Yifan Liu

ICCV 2023. [[arxiv]](https://arxiv.org/abs/2308.01045)

This repository contains the official Pytorch implementation of training & evaluation code and the pretrained models for DToP

As shown in the following figure, the network is naturally split into stages using inherent auxiliary blocks.

<img src="./resources/fig-1-1.png">

## Highlights
* **Dynamic Token Pruning** We introduce a dynamic token pruning paradigm based on the early exit of easy-to-recognize tokens for semantic segmentation transformers.
* **Controllable prune ratio** One hyperparameter to control the trade-off between computation cost and accuracy.
* **Generally applicable** e apply DToP to mainstream semantic segmentation transformers and can reduce up to 35% computational cost without a notable accuracy drop.

## Getting started
1. requirements
```
torch==2.0.0 mmcls==1.0.0.rc5, mmcv==2.0.0 mmengine==0.7.0 mmsegmentation==1.0.0rc6
```
or up-to-date mmxx series till 9 Aug 2023

## Training
To aquire the base model
```
python tools dist_train.sh config/prune/BASE_segvit_ade20k_large.py $work_dirs$
```
To prune on the base model
```
python tools dist_train_load.sh config/prune/prune_segvit_ade20k_large.py $work_dirs$ $path_to_ckpt$
```

## Eval
```
python tools/dist_test.sh config/prune/prune_segvit_ade20k_large.py $path_to_ckpt$
```

## Datasets
Please follow the instructions of [mmsegmentation](https://github.com/open-mmlab/mmsegmentation) data preparation

## Results
### Ade20k
| Method | Backbone | mIoU | GFlops | config | ckpt |
|--------------|-----------|------|--------|--------|------|
| Segvit | Vit-base | 49.6 | 109.9 | [config](./config/prune/BASE_segvit_ade20k.py) | |
| Segvit-prune | Vit-base | 49.8 | 86.8 | [config](./config/prune/prune_segvit_ade20k.py) | |
| Segvit | Vit-large | 53.3 | 617.0 | [config](./config/prune/BASE_segvit_ade20k_large.py) | |
| Segvit-prune | Vit-large | 52.8 | 412.8 | [config](./config/prune/prune_segvit_ade20k_large.py) | |

### Pascal Context
| Method | Backbone | mIoU | GFlops | config | ckpt |
|--------------|-----------|------|--------|--------|------|
| Segvit | Vit-large | 63.0 | 315.4 | [config](./config/prune/BASE_segvit_pc.py) | |
| Segvit-prune | Vit-large | 62.7 | 224.3 | [config](./config/prune/prune_segvit_pc.py) | |

### COCO-Stuff-10K
| Method | Backbone | mIoU | GFlops | config | ckpt |
|--------------|-----------|------|--------|--------|------|
| Segvit | Vit-large | 47.4 | 366.9 | [config](./config/prune/BASE_segvit_cocostuff10k.py) | |
| Segvit-prune | Vit-large | 47.1 | 276.2 | [config](./config/prune/prune_segvit_cocostuff10k.py) | |



## License
For academic use, this project is licensed under the 2-clause BSD License - see the LICENSE file for details. For commercial use, please contact the authors.

## Citation
68 changes: 68 additions & 0 deletions config/_base_/datasets/ade20k.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
# dataset settings
dataset_type = 'ADE20KDataset'
data_root = '/data/ADEChallengeData2016'
crop_size = (512, 512)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations', reduce_zero_label=True),
dict(
type='RandomResize',
scale=(2048, 512),
ratio_range=(0.5, 2.0),
keep_ratio=True),
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
dict(type='RandomFlip', prob=0.5),
dict(type='PhotoMetricDistortion'),
dict(type='PackSegInputs')
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='Resize', scale=(2048, 512), keep_ratio=True),
# add loading annotation after ``Resize`` because ground truth
# does not need to do resize data transform
dict(type='LoadAnnotations', reduce_zero_label=True),
dict(type='PackSegInputs')
]
img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
tta_pipeline = [
dict(type='LoadImageFromFile', backend_args=None),
dict(
type='TestTimeAug',
transforms=[
[
dict(type='Resize', scale_factor=r, keep_ratio=True)
for r in img_ratios
],
[
dict(type='RandomFlip', prob=0., direction='horizontal'),
dict(type='RandomFlip', prob=1., direction='horizontal')
], [dict(type='LoadAnnotations')], [dict(type='PackSegInputs')]
])
]
train_dataloader = dict(
batch_size=4,
num_workers=4,
persistent_workers=True,
sampler=dict(type='InfiniteSampler', shuffle=True),
dataset=dict(
type=dataset_type,
data_root=data_root,
data_prefix=dict(
img_path='images/training', seg_map_path='annotations/training'),
pipeline=train_pipeline))
val_dataloader = dict(
batch_size=1,
num_workers=4,
persistent_workers=True,
sampler=dict(type='DefaultSampler', shuffle=False),
dataset=dict(
type=dataset_type,
data_root=data_root,
data_prefix=dict(
img_path='images/validation',
seg_map_path='annotations/validation'),
pipeline=test_pipeline))
test_dataloader = val_dataloader

val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU'])
test_evaluator = val_evaluator
68 changes: 68 additions & 0 deletions config/_base_/datasets/ade20k_640x640.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
# dataset settings
dataset_type = 'ADE20KDataset'
data_root = '/data/ADEChallengeData2016'
crop_size = (640, 640)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations', reduce_zero_label=True),
dict(
type='RandomResize',
scale=(2560, 640),
ratio_range=(0.5, 2.0),
keep_ratio=True),
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
dict(type='RandomFlip', prob=0.5),
dict(type='PhotoMetricDistortion'),
dict(type='PackSegInputs')
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='Resize', scale=(2560, 640), keep_ratio=True),
# add loading annotation after ``Resize`` because ground truth
# does not need to do resize data transform
dict(type='LoadAnnotations', reduce_zero_label=True),
dict(type='PackSegInputs')
]
img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
tta_pipeline = [
dict(type='LoadImageFromFile', backend_args=None),
dict(
type='TestTimeAug',
transforms=[
[
dict(type='Resize', scale_factor=r, keep_ratio=True)
for r in img_ratios
],
[
dict(type='RandomFlip', prob=0., direction='horizontal'),
dict(type='RandomFlip', prob=1., direction='horizontal')
], [dict(type='LoadAnnotations')], [dict(type='PackSegInputs')]
])
]
train_dataloader = dict(
batch_size=4,
num_workers=4,
persistent_workers=True,
sampler=dict(type='InfiniteSampler', shuffle=True),
dataset=dict(
type=dataset_type,
data_root=data_root,
data_prefix=dict(
img_path='images/training', seg_map_path='annotations/training'),
pipeline=train_pipeline))
val_dataloader = dict(
batch_size=1,
num_workers=4,
persistent_workers=True,
sampler=dict(type='DefaultSampler', shuffle=False),
dataset=dict(
type=dataset_type,
data_root=data_root,
data_prefix=dict(
img_path='images/validation',
seg_map_path='annotations/validation'),
pipeline=test_pipeline))
test_dataloader = val_dataloader

val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU'])
test_evaluator = val_evaluator
69 changes: 69 additions & 0 deletions config/_base_/datasets/coco-stuff10k.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
# dataset settings
dataset_type = 'COCOStuffDataset'
data_root = '/data/coco_stuff10k'
crop_size = (512, 512)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations', reduce_zero_label=True),
dict(
type='RandomResize',
scale=(2048, 512),
ratio_range=(0.5, 2.0),
keep_ratio=True),
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
dict(type='RandomFlip', prob=0.5),
dict(type='PhotoMetricDistortion'),
dict(type='PackSegInputs')
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='Resize', scale=(2048, 512), keep_ratio=True),
# add loading annotation after ``Resize`` because ground truth
# does not need to do resize data transform
dict(type='LoadAnnotations', reduce_zero_label=True),
dict(type='PackSegInputs')
]
img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
tta_pipeline = [
dict(type='LoadImageFromFile', backend_args=None),
dict(
type='TestTimeAug',
transforms=[
[
dict(type='Resize', scale_factor=r, keep_ratio=True)
for r in img_ratios
],
[
dict(type='RandomFlip', prob=0., direction='horizontal'),
dict(type='RandomFlip', prob=1., direction='horizontal')
], [dict(type='LoadAnnotations')], [dict(type='PackSegInputs')]
])
]
train_dataloader = dict(
batch_size=4,
num_workers=4,
persistent_workers=True,
sampler=dict(type='InfiniteSampler', shuffle=True),
dataset=dict(
type=dataset_type,
data_root=data_root,
reduce_zero_label=True,
data_prefix=dict(
img_path='images/train2014', seg_map_path='annotations/train2014'),
pipeline=train_pipeline))
val_dataloader = dict(
batch_size=1,
num_workers=4,
persistent_workers=True,
sampler=dict(type='DefaultSampler', shuffle=False),
dataset=dict(
type=dataset_type,
data_root=data_root,
reduce_zero_label=True,
data_prefix=dict(
img_path='images/test2014', seg_map_path='annotations/test2014'),
pipeline=test_pipeline))
test_dataloader = val_dataloader

val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU'])
test_evaluator = val_evaluator
72 changes: 72 additions & 0 deletions config/_base_/datasets/pascal_context_59.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
# dataset settings
dataset_type = 'PascalContextDataset59'
data_root = '/data/VOCdevkit/VOC2010/'

img_scale = (2048, 520) # compromise for mmengin Resize
crop_size = (480, 480)

train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations', reduce_zero_label=True),
dict(
type='RandomResize',
scale=img_scale,
ratio_range=(0.5, 2.0),
keep_ratio=True),
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
dict(type='RandomFlip', prob=0.5),
dict(type='PhotoMetricDistortion'),
dict(type='PackSegInputs')
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='Resize', scale=img_scale, keep_ratio=True),
# add loading annotation after ``Resize`` because ground truth
# does not need to do resize data transform
dict(type='LoadAnnotations', reduce_zero_label=True),
dict(type='PackSegInputs')
]
img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
tta_pipeline = [
dict(type='LoadImageFromFile', backend_args=None),
dict(
type='TestTimeAug',
transforms=[
[
dict(type='Resize', scale_factor=r, keep_ratio=True)
for r in img_ratios
],
[
dict(type='RandomFlip', prob=0., direction='horizontal'),
dict(type='RandomFlip', prob=1., direction='horizontal')
], [dict(type='LoadAnnotations')], [dict(type='PackSegInputs')]
])
]
train_dataloader = dict(
batch_size=4,
num_workers=4,
persistent_workers=True,
sampler=dict(type='InfiniteSampler', shuffle=True),
dataset=dict(
type=dataset_type,
data_root=data_root,
data_prefix=dict(
img_path='JPEGImages', seg_map_path='SegmentationClassContext'),
ann_file='ImageSets/SegmentationContext/train.txt',
pipeline=train_pipeline))
val_dataloader = dict(
batch_size=1,
num_workers=4,
persistent_workers=True,
sampler=dict(type='DefaultSampler', shuffle=False),
dataset=dict(
type=dataset_type,
data_root=data_root,
data_prefix=dict(
img_path='JPEGImages', seg_map_path='SegmentationClassContext'),
ann_file='ImageSets/SegmentationContext/val.txt',
pipeline=test_pipeline))
test_dataloader = val_dataloader

val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU'])
test_evaluator = val_evaluator
15 changes: 15 additions & 0 deletions config/_base_/default_runtime.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
default_scope = 'mmseg'
env_cfg = dict(
cudnn_benchmark=True,
mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
dist_cfg=dict(backend='nccl'),
)
vis_backends = [dict(type='LocalVisBackend')]
visualizer = dict(
type='SegLocalVisualizer', vis_backends=vis_backends, name='visualizer')
log_processor = dict(by_epoch=False)
log_level = 'INFO'
load_from = None
resume = False

tta_model = dict(type='SegTTAModel')
Loading

0 comments on commit aeecb34

Please sign in to comment.