An unofficial implementation of "RMT: Retentive Networks Meet Vision Transformers". I created this repo to exercise my paper-to-code translation skills while waiting for the official implementation to be published at https://github.com/qhfan/RMT.
RMT is an architecture that adopts the retention mechanism proposed by Sun et al. in the paper "Retentive Network: A Successor to Transformer for Large Language Models" and serves capably as a general-purpose backbone for computer vision. It extends the retention mechanism from unidirectional, one-dimensional data (sequential data such as text) to bidirectional, two-dimensional data (images). Moreover, unlike the original Retentive Network, RMT does not use different representations for training and inference, because the recurrent form greatly disrupts the parallelism of the model and results in very slow inference.
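To give a rough feel for what bidirectional, two-dimensional retention looks like, below is a minimal sketch of the parallel form over flattened image tokens, where the decay is driven by the Manhattan distance between spatial positions. The function names, the default gamma, and the exact placement of the softmax and decay mask are my own simplifications for illustration, not the implementation used in this repo; the real architecture adds further details that this sketch deliberately omits.

```python
import torch


def manhattan_decay_mask(height, width, gamma):
    """Bidirectional 2D decay mask: gamma ** Manhattan distance between token positions.

    Hypothetical sketch; the actual decay schedule in RMT may differ.
    """
    ys, xs = torch.meshgrid(torch.arange(height), torch.arange(width), indexing="ij")
    pos = torch.stack([ys.flatten(), xs.flatten()], dim=-1).float()   # (N, 2), N = H * W
    dist = (pos[:, None, :] - pos[None, :, :]).abs().sum(-1)          # (N, N) Manhattan distance
    return gamma ** dist


def retention_2d(q, k, v, height, width, gamma=0.9):
    """Parallel (non-recurrent) bidirectional retention over flattened image tokens.

    q, k, v: (batch, heads, N, dim) with N = height * width.
    """
    mask = manhattan_decay_mask(height, width, gamma).to(q.device, q.dtype)
    scores = (q @ k.transpose(-2, -1)) * q.shape[-1] ** -0.5          # (B, heads, N, N)
    scores = scores.softmax(dim=-1) * mask                            # weight by spatial decay
    return scores @ v


if __name__ == "__main__":
    # Tiny usage example on random tokens from a 7x7 feature map.
    B, heads, H, W, dim = 2, 4, 7, 7, 32
    q = torch.randn(B, heads, H * W, dim)
    k = torch.randn(B, heads, H * W, dim)
    v = torch.randn(B, heads, H * W, dim)
    out = retention_2d(q, k, v, H, W)
    print(out.shape)  # torch.Size([2, 4, 49, 32])
```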
RMT achieves strong performance on COCO object detection (51.6 box AP and 45.9 mask AP) and ADE20K semantic segmentation (52.0 mIoU), surpassing previous models by a large margin.
This repo was created by forking the mmpretrain repository (https://github.com/open-mmlab/mmpretrain).
Below are quick steps for installation:
conda create -n open-mmlab python=3.8 pytorch==1.10.1 torchvision==0.11.2 cudatoolkit=11.3 -c pytorch -y
conda activate open-mmlab
pip install openmim
git clone https://github.com/open-mmlab/mmpretrain.git
cd mmpretrain
mim install -e .
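As a quick sanity check that the editable install worked (assuming the environment created above is active), you can try importing the package and printing its version:

python -c "import mmpretrain; print(mmpretrain.__version__)"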
Please refer to the installation documentation for more detailed installation steps and dataset preparation.
To train the RMT model, you can use tools/train.py. Here is the full usage of the script:
python tools/train.py ${CONFIG_FILE} [ARGS]
where CONFIG_FILE is the path to the config file. Several predefined config files are available inside the configs/rmt folder. One example is rmt-tiny_b128_cifar10.py, which trains the tiny RMT configuration with a batch size of 128 on the CIFAR-10 dataset.
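For example, a run with that config could look like the following (the --work-dir argument is optional and only sets where logs and checkpoints are saved; the path here is just an illustrative choice):

python tools/train.py configs/rmt/rmt-tiny_b128_cifar10.py --work-dir work_dirs/rmt-tiny_b128_cifar10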
Please refer to these tutorials about the basic usage of MMPretrain for new users:
- Learn about Configs
- Prepare Dataset
- Inference with existing models
- Train
- Test
- Downstream tasks
- MMPretrain Documentation.
MMPreTrain is an open source project contributed to by researchers and engineers from various colleges and companies. Thanks to all the contributors who implemented their methods or added new features, as well as users who gave valuable feedback. I would also like to thank the authors for writing such a wonderful paper.
If you find this project useful in your research, please consider citing:
@misc{rmt-unofficial,
    title={RMT Unofficial Implementation},
    author={Farros Alferro},
    howpublished={\url{https://github.com/farrosalferro/RMT-unofficial}},
    year={2023}
}
This project is released under the Apache 2.0 license.