Skip to content

Commit

Permalink
[Feature] barlowtwins (#207)
Browse files Browse the repository at this point in the history
* [Fix]: Fix mmcls upgrade bug (#235)

* [Feature]: Add multi machine dist_train (#232)

* [Feature]: Add multi machine dist_train

* [Fix]: Change bash to sh

* [Fix]: Fix missing sh suffix

* [Refactor]: Change bash to sh

* [Refactor] Add unit test (#234)

* [Refactor] add unit test

* update workflow

* update

* [Fix] fix lint

* update test

* refactor moco and densecl unit test

* fix lint

* add unit test

* update unit test

* remove modification

* [Feature]: Add MAE metafile (#238)

* [Feature]: Add MAE metafile

* [Fix]: Fix lint

* [Fix]: Change LARS to AdamW in the metafile of MAE

* Add barlowtwins

* Add unit test for barlowtwins

* Adjust training params

* add decorator to pass CI

* adjust params

* Add barlowtwins

* Add unit test for barlowtwins

* Adjust training params

* add decorator to pass CI

* adjust params

* add barlowtwins configs

* revise LatentCrossCorrelationHead

* modify ut to save memory

* add metafile

* add barlowtwins results to model zoo

* add barlow twins to homepage

* fix batch size bug

* add algorithm readme

* add type hints

* reorganize the model zoo

* remove one config

* recover the config

* add missing docstring

* revise barlowtwins

* reorganize coco and voc benchmark

* add barlowtwins to index.rst

* revise docstring

Co-authored-by: Yuan Liu <[email protected]>
Co-authored-by: Yixiao Fang <[email protected]>
Co-authored-by: fangyixiao18 <[email protected]>
  • Loading branch information
4 people authored Apr 27, 2022
1 parent b959934 commit c525544
Show file tree
Hide file tree
Showing 21 changed files with 549 additions and 92 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ Supported algorithms:
- [x] [SwAV (NeurIPS'2020)](https://arxiv.org/abs/2006.09882)
- [x] [DenseCL (CVPR'2021)](https://arxiv.org/abs/2011.09157)
- [x] [SimSiam (CVPR'2021)](https://arxiv.org/abs/2011.10566)
- [x] [Barlow Twins (ICML'2021)](https://arxiv.org/abs/2103.03230)
- [x] [MoCo v3 (ICCV'2021)](https://arxiv.org/abs/2104.02057)
- [x] [MAE](https://arxiv.org/abs/2111.06377)
- [x] [SimMIM](https://arxiv.org/abs/2111.09886)
Expand Down
1 change: 1 addition & 0 deletions README_zh-CN.md
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ MMSelfSup 和 OpenSelfSup 的不同点写在 [对比文档](docs/en/compatibilit
- [x] [SwAV (NeurIPS'2020)](https://arxiv.org/abs/2006.09882)
- [x] [DenseCL (CVPR'2021)](https://arxiv.org/abs/2011.09157)
- [x] [SimSiam (CVPR'2021)](https://arxiv.org/abs/2011.10566)
- [x] [Barlow Twins (ICML'2021)](https://arxiv.org/abs/2103.03230)
- [x] [MoCo v3 (ICCV'2021)](https://arxiv.org/abs/2104.02057)
- [x] [MAE](https://arxiv.org/abs/2111.06377)
- [x] [SimMIM](https://arxiv.org/abs/2111.09886)
Expand Down
22 changes: 22 additions & 0 deletions configs/selfsup/_base_/models/barlowtwins.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# model settings
model = dict(
type='BarlowTwins',
backbone=dict(
type='ResNet',
depth=50,
in_channels=3,
out_indices=[4], # 0: conv-1, x: stage-x
norm_cfg=dict(type='SyncBN'),
zero_init_residual=True),
neck=dict(
type='NonLinearNeck',
in_channels=2048,
hid_channels=8192,
out_channels=8192,
num_layers=3,
with_last_bn=False,
with_last_bn_affine=False,
with_avg_pool=True,
init_cfg=dict(
type='Kaiming', distribution='uniform', layer=['Linear'])),
head=dict(type='LatentCrossCorrelationHead', in_channels=8192))
52 changes: 52 additions & 0 deletions configs/selfsup/barlowtwins/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# BarlowTwins

> [Barlow Twins: Self-Supervised Learning via Redundancy Reduction](https://arxiv.org/abs/2103.03230)
<!-- [ALGORITHM] -->

## Abstract

Self-supervised learning (SSL) is rapidly closing the gap with supervised methods on large computer vision benchmarks. A successful approach to SSL is to learn embeddings which are invariant to distortions of the input sample. However, a recurring issue with this approach is the existence of trivial constant solutions. Most current methods avoid such solutions by careful implementation details. We propose an objective function that naturally avoids collapse by measuring the cross-correlation matrix between the outputs of two identical networks fed with distorted versions of a sample, and making it as close to the identity matrix as possible. This causes the embedding vectors of distorted versions of a sample to be similar, while minimizing the redundancy between the components of these vectors. The method is called Barlow Twins, owing to neuroscientist H. Barlow's redundancy-reduction principle applied to a pair of identical networks. Barlow Twins does not require large batches nor asymmetry between the network twins such as a predictor network, gradient stopping, or a moving average on the weight updates. Intriguingly it benefits from very high-dimensional output vectors. Barlow Twins outperforms previous methods on ImageNet for semi-supervised classification in the low-data regime, and is on par with current state of the art for ImageNet classification with a linear classifier head, and for transfer tasks of classification and object detection.

<div align="center">
<img src="https://user-images.githubusercontent.com/36138628/163914714-082de804-0b5f-4024-94f9-880e6ef334fa.png" width="800" />
</div>

## Results and Models

**Back to [model_zoo.md](https://github.com/open-mmlab/mmselfsup/blob/master/docs/en/model_zoo.md) to download models.**

In this page, we provide benchmarks as much as possible to evaluate our pre-trained models. If not mentioned, all models are pre-trained on ImageNet-1k dataset.

### Classification

The classification benchmarks includes 1 downstream task datasets, **ImageNet**. If not specified, the results are Top-1 (%).

#### ImageNet Linear Evaluation

The **Feature1 - Feature5** don't have the GlobalAveragePooling, the feature map is pooled to the specific dimensions and then follows a Linear layer to do the classification. Please refer to [resnet50_mhead_8xb32-steplr-90e.py](https://github.com/open-mmlab/mmselfsup/blob/master/configs/benchmarks/classification/imagenet/resnet50_mhead_8xb32-steplr-90e_in1k.py) for details of config.

The **AvgPool** result is obtained from Linear Evaluation with GlobalAveragePooling. Please refer to [resnet50_8xb32-steplr-100e_in1k.py](https://github.com/open-mmlab/mmselfsup/blob/master/configs/benchmarks/classification/imagenet/resnet50_8xb32-steplr-100e_in1k.py) for details of config.

| Self-Supervised Config | Feature1 | Feature2 | Feature3 | Feature4 | Feature5 | AvgPool |
| ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | -------- | -------- | -------- | -------- | -------- | ------- |
| [barlowtwins_resnet50_8xb256-coslr-300e_in1k](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/arlowtwins/barlowtwins_resnet50_8xb256-coslr-300e_in1k.py) | 15.51 | 33.98 | 45.96 | 61.90 | 71.01 | 71.66 |

#### ImageNet Nearest-Neighbor Classification

The results are obtained from the features after GlobalAveragePooling. Here, k=10 to 200 indicates different number of nearest neighbors.

| Self-Supervised Config | k=10 | k=20 | k=100 | k=200 |
| ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ---- | ---- | ----- | ----- |
| [barlowtwins_resnet50_8xb256-coslr-300e_in1k](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/arlowtwins/barlowtwins_resnet50_8xb256-coslr-300e_in1k.py) | 63.6 | 63.8 | 62.7 | 61.9 |

## Citation

```bibtex
@inproceedings{zbontar2021barlow,
title={Barlow twins: Self-supervised learning via redundancy reduction},
author={Zbontar, Jure and Jing, Li and Misra, Ishan and LeCun, Yann and Deny, St{\'e}phane},
booktitle={International Conference on Machine Learning},
year={2021},
}
```
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
_base_ = 'barlowtwins_resnet50_8xb256-coslr-300e_in1k.py'

# runtime settings
runner = dict(type='EpochBasedRunner', max_epochs=1000)
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
_base_ = [
'../_base_/models/barlowtwins.py',
'../_base_/datasets/imagenet_byol.py',
'../_base_/schedules/lars_coslr-200e_in1k.py',
'../_base_/default_runtime.py',
]

data = dict(samples_per_gpu=256)

# optimizer
optimizer = dict(
type='LARS',
lr=1.6,
momentum=0.9,
weight_decay=1e-6,
paramwise_options={
'(bn|gn)(\\d+)?.(weight|bias)':
dict(weight_decay=0, lr_mult=0.024, lars_exclude=True),
'bias':
dict(weight_decay=0, lr_mult=0.024, lars_exclude=True),
# bn layer in ResNet block downsample module
'downsample.1':
dict(weight_decay=0, lr_mult=0.024, lars_exclude=True),
})

# learning policy
lr_config = dict(
policy='CosineAnnealing',
by_epoch=False,
min_lr=0.0016,
warmup='linear',
warmup_iters=10,
warmup_ratio=1.6e-4, # cannot be 0
warmup_by_epoch=True)

# runtime settings
# the max_keep_ckpts controls the max number of ckpt file in your work_dirs
# if it is 3, when CheckpointHook (in mmcv) saves the 4th ckpt
# it will remove the oldest one to keep the number of total ckpts as 3
checkpoint_config = dict(interval=10, max_keep_ckpts=3)

runner = dict(type='EpochBasedRunner', max_epochs=300)
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
_base_ = 'barlowtwins_resnet50_8xb256-coslr-300e_in1k.py'

data = dict(samples_per_gpu=32)

# additional hooks
# interval for accumulate gradient, total 8*32*8(interval)=2048
update_interval = 8

# optimizer
optimizer_config = dict(update_interval=update_interval)

# runtime settings
runner = dict(type='EpochBasedRunner', max_epochs=100)
28 changes: 28 additions & 0 deletions configs/selfsup/barlowtwins/metafile.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
Collections:
- Name: BarlowTwins
Metadata:
Training Data: ImageNet-1k
Training Techniques:
- LARS
Training Resources: 8x A100 GPUs
Architecture:
- ResNet
- BarlowTwins
Paper:
URL: https://arxiv.org/abs/2103.03230
Title: "Barlow Twins: Self-Supervised Learning via Redundancy Reduction"
README: configs/selfsup/barlowtwins/README.md

Models:
- Name: barlowtwins_resnet50_8xb256-coslr-300e_in1k
In Collection: BarlowTwins
Metadata:
Epochs: 300
Batch Size: 2048
Results:
- Task: Self-Supervised Image Classification
Dataset: ImageNet-1k
Metrics:
Top 1 Accuracy: 71.66
Config: configs/selfsup/barlowtwins/barlowtwins_resnet50_8xb256-coslr-300e_in1k.py
Weights: https://download.openmmlab.com/mmselfsup/barlowtwins/barlowtwins_resnet50_8xb256-coslr-300e_in1k_20220419-5ae15f89.pth
52 changes: 52 additions & 0 deletions docs/en/algorithms/barlowtwins.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# BarlowTwins

> [Barlow Twins: Self-Supervised Learning via Redundancy Reduction](https://arxiv.org/abs/2103.03230)
<!-- [ALGORITHM] -->

## Abstract

Self-supervised learning (SSL) is rapidly closing the gap with supervised methods on large computer vision benchmarks. A successful approach to SSL is to learn embeddings which are invariant to distortions of the input sample. However, a recurring issue with this approach is the existence of trivial constant solutions. Most current methods avoid such solutions by careful implementation details. We propose an objective function that naturally avoids collapse by measuring the cross-correlation matrix between the outputs of two identical networks fed with distorted versions of a sample, and making it as close to the identity matrix as possible. This causes the embedding vectors of distorted versions of a sample to be similar, while minimizing the redundancy between the components of these vectors. The method is called Barlow Twins, owing to neuroscientist H. Barlow's redundancy-reduction principle applied to a pair of identical networks. Barlow Twins does not require large batches nor asymmetry between the network twins such as a predictor network, gradient stopping, or a moving average on the weight updates. Intriguingly it benefits from very high-dimensional output vectors. Barlow Twins outperforms previous methods on ImageNet for semi-supervised classification in the low-data regime, and is on par with current state of the art for ImageNet classification with a linear classifier head, and for transfer tasks of classification and object detection.

<div align="center">
<img src="https://user-images.githubusercontent.com/36138628/163914714-082de804-0b5f-4024-94f9-880e6ef334fa.png" width="800" />
</div>

## Results and Models

**Back to [model_zoo.md](https://github.com/open-mmlab/mmselfsup/blob/master/docs/en/model_zoo.md) to download models.**

In this page, we provide benchmarks as much as possible to evaluate our pre-trained models. If not mentioned, all models are pre-trained on ImageNet-1k dataset.

### Classification

The classification benchmarks includes 1 downstream task datasets, **ImageNet**. If not specified, the results are Top-1 (%).

#### ImageNet Linear Evaluation

The **Feature1 - Feature5** don't have the GlobalAveragePooling, the feature map is pooled to the specific dimensions and then follows a Linear layer to do the classification. Please refer to [resnet50_mhead_8xb32-steplr-90e.py](https://github.com/open-mmlab/mmselfsup/blob/master/configs/benchmarks/classification/imagenet/resnet50_mhead_8xb32-steplr-90e_in1k.py) for details of config.

The **AvgPool** result is obtained from Linear Evaluation with GlobalAveragePooling. Please refer to [resnet50_8xb32-steplr-100e_in1k.py](https://github.com/open-mmlab/mmselfsup/blob/master/configs/benchmarks/classification/imagenet/resnet50_8xb32-steplr-100e_in1k.py) for details of config.

| Self-Supervised Config | Feature1 | Feature2 | Feature3 | Feature4 | Feature5 | AvgPool |
| ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | -------- | -------- | -------- | -------- | -------- | ------- |
| [barlowtwins_resnet50_8xb256-coslr-300e_in1k](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/arlowtwins/barlowtwins_resnet50_8xb256-coslr-300e_in1k.py) | 15.51 | 33.98 | 45.96 | 61.90 | 71.01 | 71.66 |

#### ImageNet Nearest-Neighbor Classification

The results are obtained from the features after GlobalAveragePooling. Here, k=10 to 200 indicates different number of nearest neighbors.

| Self-Supervised Config | k=10 | k=20 | k=100 | k=200 |
| ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ---- | ---- | ----- | ----- |
| [barlowtwins_resnet50_8xb256-coslr-300e_in1k](https://github.com/open-mmlab/mmselfsup/blob/master/configs/selfsup/arlowtwins/barlowtwins_resnet50_8xb256-coslr-300e_in1k.py) | 63.6 | 63.8 | 62.7 | 61.9 |

## Citation

```bibtex
@inproceedings{zbontar2021barlow,
title={Barlow twins: Self-supervised learning via redundancy reduction},
author={Zbontar, Jure and Jing, Li and Misra, Ishan and LeCun, Yann and Deny, St{\'e}phane},
booktitle={International Conference on Machine Learning},
year={2021},
}
```
1 change: 1 addition & 0 deletions docs/en/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ Welcome to MMSelfSup's documentation!
algorithms/mocov3.md
algorithms/mae.md
algorithms/simmim.md
algorithms/barlowtwins.md


.. toctree::
Expand Down
Loading

0 comments on commit c525544

Please sign in to comment.