- Introduction
- Installation
- Usage
- Classification results on CV tasks
- Classification results on NLP tasks
- Visualization
- Change log
- Acknowledgments
- Contact
This repository is the official implementation of "SoT: Delving Deeper into Classification Head for Transformer". It contains the PyTorch source code and models for image classification and text classification tasks.
Please consider citing the paper if you find it useful.
@article{SoT,
author = {Jiangtao Xie and Ruiren Zeng and Qilong Wang and Ziqi Zhou and Peihua Li},
title = {SoT: Delving Deeper into Classification Head for Transformer},
journal = {arXiv:2104.10935v2},
year = {2021}
}
For classification tasks in both CV and NLP, current works based on pure transformer architectures pay little attention to the classification head: they apply only the classification token (ClassT) in the classifier and neglect the word tokens (WordT), which contain rich information. In our experiments, we show that ClassT and WordT are highly complementary and that fusing all tokens further boosts performance. We therefore propose a novel classification paradigm that jointly utilizes ClassT and WordT, in which multi-headed global cross-covariance pooling with singular value power normalization is proposed to effectively harness the rich information of WordT. We evaluate the proposed classification scheme on both CV and NLP tasks, achieving very competitive performance compared with the counterparts.
- clone
git clone https://github.com/jiangtaoxie/SoT.git
cd SoT/
- install dependencies
pip install -r requirments.txt
main libs: torch(>=1.7.0) | timm(==0.3.4) | apex (alternative)
- install
python setup.py install
Please prepare the dataset as the following file structure:
.
├── train
│ ├── class1
│ │ ├── class1_001.jpg
│ │ ├── class1_002.jpg
│ │ └── ...
│ ├── class2
│ ├── class3
│ ├── ...
│ ├── ...
│ └── classN
└── val
├── class1
│ ├── class1_001.jpg
│ ├── class1_002.jpg
│ └── ...
├── class2
├── class3
├── ...
├── ...
└── classN
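Before launching a long training run, it can help to confirm that the dataset directory matches the layout above. The following is a minimal sketch using only the Python standard library; the helper names `summarize_split` and `check_dataset` and the set of accepted image extensions are assumptions made for illustration and are not part of this repository.

```python
from pathlib import Path

# Assumed set of image extensions; adjust to match your dataset.
IMAGE_EXTS = {".jpg", ".jpeg", ".png"}


def summarize_split(split_dir):
    """Return {class_name: image_count} for one split directory."""
    split_dir = Path(split_dir)
    counts = {}
    for class_dir in sorted(p for p in split_dir.iterdir() if p.is_dir()):
        counts[class_dir.name] = sum(
            1
            for f in class_dir.iterdir()
            if f.is_file() and f.suffix.lower() in IMAGE_EXTS
        )
    return counts


def check_dataset(root):
    """Return a list of problems found in the train/ and val/ layout."""
    root = Path(root)
    problems = []
    splits = {}
    for split_name in ("train", "val"):
        split_dir = root / split_name
        if not split_dir.is_dir():
            problems.append(f"missing directory: {split_name}/")
            continue
        splits[split_name] = summarize_split(split_dir)
        for class_name, n in splits[split_name].items():
            if n == 0:
                problems.append(f"{split_name}/{class_name} contains no images")
    # Both splits should expose exactly the same class folders.
    if len(splits) == 2 and set(splits["train"]) != set(splits["val"]):
        problems.append("class folders differ between train/ and val/")
    return problems
```

An empty list from `check_dataset(DATA_ROOT)` means both splits exist, share the same class folders, and every class folder holds at least one image.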
- Training from scratch:
You can train the SoT family of models with the following command:
sh ./distributed_train.sh $NODE_NUM $DATA_ROOT --model $MODEL_NAME -b $BATCH_SIZE --lr $INIT_LR \
--weight-decay $WEIGHT_DECAY \
--img-size $RESOLUTION \
--amp
Basic hyper-parameters of SoT:
Hyper-parameter | SoT-Tiny | SoT-Small | SoT-Base |
---|---|---|---|
Batch size | 1024 | 1024 | 512 |
Init. LR | 1e-3 | 1e-3 | 5e-4 |
Weight Decay | 3e-2 | 3e-2 | 6.5e-2 |
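To make the link between this table and the training command explicit, here is a small sketch that fills in the command line for a given model. The dictionary values are copied from the table; the helper name `build_train_command` and the default resolution of 224 (taken from the single-crop 224x224 evaluation setting) are assumptions. Whether `-b` expects a per-process or a total batch size is not stated here, so the sketch passes the table value through unchanged; the scripts in ./scripts remain the reference.

```python
import shlex

# Values copied from the hyper-parameter table; kept as strings where the
# table uses scientific notation so they are passed through verbatim.
SOT_HPARAMS = {
    "SoT_Tiny": {"batch_size": 1024, "lr": "1e-3", "weight_decay": "3e-2"},
    "SoT_Small": {"batch_size": 1024, "lr": "1e-3", "weight_decay": "3e-2"},
    "SoT_Base": {"batch_size": 512, "lr": "5e-4", "weight_decay": "6.5e-2"},
}


def build_train_command(model_name, data_root, node_num, img_size=224):
    """Assemble the distributed_train.sh command line for one model."""
    hp = SOT_HPARAMS[model_name]
    args = [
        "sh", "./distributed_train.sh", str(node_num), str(data_root),
        "--model", model_name,
        "-b", str(hp["batch_size"]),
        "--lr", hp["lr"],
        "--weight-decay", hp["weight_decay"],
        "--img-size", str(img_size),  # assumed default resolution
        "--amp",
    ]
    return shlex.join(args)


print(build_train_command("SoT_Base", "/path/to/imagenet", 8))
```

The path `/path/to/imagenet` and the node count `8` are placeholders.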
We also provide shell scripts in ./scripts for convenient reproduction. You can run:
sh ./scripts/train_SoT_Tiny.sh # reproduce SoT-Tiny
sh ./scripts/train_SoT_Small.sh # reproduce SoT-Small
sh ./scripts/train_SoT_Base.sh # reproduce SoT-Base
- Evaluation
On validation set of ImageNet-1K:
python main.py $DATA_ROOT $MODEL_NAME --b 256 --eval_checkpoint $CHECKPOINT_PATH
On ImageNet-A:
python main.py $DATA_ROOT $MODEL_NAME --b 256 --eval_checkpoint $CHECKPOINT_PATH --IN_A
$MODEL_NAME can be SoT_Tiny, SoT_Small, or SoT_Base.
- import the sot_src package
from sot_src.model import Classifier, OnlyVisualTokensClassifier
- define the classification head
classification_head_config = dict(
    type='MGCrP',
    fusion_type='sum_fc',
    args=dict(
        dim=256,
        num_heads=6,
        wr_dim=14,
        normalization=dict(
            type='svPN',
            alpha=0.5,
            iterNum=1,
            svNum=1,
            regular=None, # or nn.Dropout(0.5)
            input_dim=14,
        ),
    ),
)
classifier = Classifier(classification_head_config)
Notes:
- if your backbone has no classification token, use OnlyVisualTokensClassifier in place of Classifier
- key arguments:
- dim: equal to the embedding dimension
- wr_dim: dimension of W,R; you can control the final representation dimension by adjusting it
- regular: you can use dropout regularization to alleviate overfitting
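To give a feel for how wr_dim controls the size of the final representation, the following NumPy sketch illustrates multi-headed cross-covariance pooling followed by singular value power normalization. It is not the sot_src implementation. It assumes that each head projects the word tokens with its own pair of matrices W and R of shape (dim, wr_dim), that the textbook (centered) cross-covariance is used, and that the normalization raises every singular value to the power alpha. The iterNum and svNum options in the config suggest the repository uses an iterative approximation; the sketch swaps in an exact SVD instead. All function names are hypothetical.

```python
import numpy as np


def cross_covariance(tokens, W, R):
    """Cross-covariance between two linear projections of the word tokens.

    tokens: (n_tokens, dim); W, R: (dim, wr_dim). Returns (wr_dim, wr_dim).
    """
    Y = tokens @ W
    Z = tokens @ R
    Y = Y - Y.mean(axis=0, keepdims=True)  # center over the token axis
    Z = Z - Z.mean(axis=0, keepdims=True)
    return Y.T @ Z / tokens.shape[0]


def sv_power_norm(Q, alpha=0.5):
    """Raise every singular value of Q to the power alpha (exact SVD)."""
    U, s, Vt = np.linalg.svd(Q, full_matrices=False)
    return (U * s ** alpha) @ Vt


def multi_head_pool(tokens, Ws, Rs, alpha=0.5):
    """Concatenate the normalized cross-covariance matrices of all heads."""
    heads = [
        sv_power_norm(cross_covariance(tokens, W, R), alpha).ravel()
        for W, R in zip(Ws, Rs)
    ]
    return np.concatenate(heads)


# Sizes taken from the config above; 196 tokens is an assumed 14x14 grid.
rng = np.random.default_rng(0)
dim, num_heads, wr_dim, n_tokens = 256, 6, 14, 196
tokens = rng.standard_normal((n_tokens, dim))
Ws = [rng.standard_normal((dim, wr_dim)) for _ in range(num_heads)]
Rs = [rng.standard_normal((dim, wr_dim)) for _ in range(num_heads)]
rep = multi_head_pool(tokens, Ws, Rs)
print(rep.shape)  # (1176,) = num_heads * wr_dim * wr_dim
```

Under these assumptions the representation length grows quadratically with wr_dim, which is why adjusting wr_dim is an effective way to control it.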
In addition, we provide implementations based on DeiT and Swin-Transformer for CV tasks and on BERT for NLP tasks for reference.
You can also use the proposed TokenEmbedding module, which is implemented with DenseNet blocks, as follows:
from sot_src import TokenEmbed
patch_embed_config = dict(
    type='DenseNet',
    embedding_dim=64,
    large_output=False, # For a 224x224 input image: True gives a 56x56 output, False gives a 14x14 output
)
patch_embed = TokenEmbed(patch_embed_config)
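The effect of large_output on the number of word tokens can be worked out from the figures in the comment above. The sketch below assumes a fixed downsampling factor, 4 when large_output is True (224 to 56) and 16 when it is False (224 to 14), and extends it to other resolutions; that extension is an assumption, and `token_grid` is a hypothetical helper, not part of sot_src.

```python
def token_grid(img_size, large_output):
    """Return (side length, token count) of the embedding output.

    Assumes a fixed downsampling factor inferred from the 224x224 case:
    4 when large_output is True, 16 when it is False.
    """
    stride = 4 if large_output else 16
    if img_size % stride != 0:
        raise ValueError(f"img_size {img_size} is not divisible by {stride}")
    side = img_size // stride
    return side, side * side


print(token_grid(224, large_output=True))   # (56, 3136)
print(token_grid(224, large_output=False))  # (14, 196)
```

The 16x difference in token count is worth keeping in mind, since the cost of the transformer layers that follow depends on the sequence length.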
Accuracy (single crop 224x224, %) on the validation set of ImageNet-1K and ImageNet-A
Backbone | ImageNet Top-1 Acc. | ImageNet-A Top-1 Acc. | #Params (M) | GFLOPs | Weight |
---|---|---|---|---|---|
SoT-Tiny | 80.3 | 21.5 | 7.7 | 2.5 | Coming soon |
SoT-Small | 82.7 | 31.8 | 26.9 | 5.8 | Coming soon |
SoT-Base | 83.5 | 34.6 | 76.8 | 14.5 | Coming soon |
Backbone | ImageNet Top-1 Acc. | ImageNet-A Top-1 Acc. | #Params (M) | GFLOPs | Weight |
---|---|---|---|---|---|
DeiT-T | 72.2 | 7.3 | 5.7 | 1.3 | model |
DeiT-T + ours | 78.6 | 17.5 | 7.0 | 2.3 | Coming soon |
DeiT-S | 79.8 | 18.9 | 22.1 | 4.6 | model |
DeiT-S + ours | 82.7 | 31.8 | 26.9 | 5.8 | Coming soon |
DeiT-B | 81.8 | 27.4 | 86.6 | 17.6 | model |
DeiT-B + ours | 82.9 | 29.1 | 94.9 | 18.2 | Coming soon |
Backbone | ImageNet Top-1 Acc. | ImageNet-A Top-1 Acc. | #Params (M) | GFLOPs | Weight |
---|---|---|---|---|---|
Swin-T | 81.3 | 21.6 | 28.3 | 4.5 | model |
Swin-T + ours | 83.0 | 33.5 | 31.6 | 6.0 | Coming soon |
Swin-B | 83.5 | 35.8 | 87.8 | 15.4 | model |
Swin-B + ours | 84.0 | 42.9 | 95.9 | 16.9 | Coming soon |
Notes:
- "+ ours" means we adopt the proposed classification head and token embedding module on top of the other architectures.
- We report the accuracy of models trained from scratch on ImageNet-1K.
Accuracy (Top-1, %) on four selected tasks from the General Language Understanding Evaluation (GLUE) benchmark.
- CoLA (The Corpus of Linguistic Acceptability): the task is to judge whether an English sentence is grammatical.
- RTE (The Recognizing Textual Entailment datasets): the task is to determine whether a given pair of sentences is an entailment.
- MNLI (The Multi-Genre Natural Language Inference Corpus): the task is to classify a given pair of sentences, drawn from multiple sources, as entailment, contradiction, or neutral.
- QNLI (Question-answering Natural Language Inference Corpus): the task is to decide whether a question-answer sentence pair is an entailment.
Backbone | CoLA | RTE | MNLI | QNLI | Weight |
---|---|---|---|---|---|
GPT | 54.32 | 63.17 | 82.10 | 86.36 | model |
GPT + ours | 57.25 | 65.35 | 82.41 | 87.13 | Coming soon |
BERT-base | 54.82 | 67.15 | 83.47 | 90.11 | model |
BERT-base + ours | 58.03 | 69.31 | 84.20 | 90.78 | Coming soon |
BERT-large | 60.63 | 73.65 | 85.90 | 91.82 | model |
BERT-large + ours | 61.82 | 75.09 | 86.46 | 92.37 | Coming soon |
SpanBERT-base | 57.48 | 73.65 | 85.53 | 92.71 | model |
SpanBERT-base + ours | 63.77 | 77.26 | 86.13 | 93.31 | Coming soon |
SpanBERT-large | 64.32 | 78.34 | 87.89 | 94.22 | model |
SpanBERT-large + ours | 65.94 | 79.79 | 88.16 | 94.49 | Coming soon |
RoBERTa-base | 61.58 | 77.60 | 87.50 | 92.70 | model |
RoBERTa-base + ours | 65.28 | 80.50 | 87.90 | 93.10 | Coming soon |
RoBERTa-large | 67.98 | 86.60 | 90.20 | 94.70 | model |
RoBERTa-large + ours | 70.90 | 88.10 | 90.50 | 95.00 | Coming soon |
We further analyze the models for CV and NLP tasks through visualization, using SoT-Tiny and BERT-base as the respective backbones. We compare three variants based on SoT-Tiny and BERT-base:
- ClassT: only the classification token is used for the classifier
- WordT: only the word tokens are used for the classifier
- ClassT+WordT: both the classification token and the word tokens are used for the classifier, based on the sum scheme.
√: correct prediction; ✗: incorrect prediction
We can see that ClassT is more suitable for classifying categories associated with backgrounds and the whole context, while WordT performs classification primarily based on some local discriminative regions. Our ClassT+WordT makes full use of the merits of both the word tokens and the classification token: by exploiting both local and global information, it can focus on the most important regions for better classification.
We selected some examples from the CoLA task, which aims to judge whether an English sentence is grammatical. A greener background color denotes a stronger impact of the word on the classification, while a bluer one implies a weaker impact. We can see that the proposed ClassT+WordT highlights all important words in the sentence while the other two fail to do so, which helps to boost classification performance.
pytorch: https://github.com/pytorch/pytorch
timm: https://github.com/rwightman/pytorch-image-models
T2T-ViT: https://github.com/yitu-opensource/T2T-ViT
If you have any questions or suggestions, please contact me