This repo is the official implementation of DNABERT-S: Learning Species-Aware DNA Embedding with Genome Foundation Models.
- 1. Introduction
- 2. Model and Data
- 3. Setup Environment
- 4. Quick Start
- 5. Training
- 6. Evaluation
- 7. Citation
DNABERT-S is a foundation model based on DNABERT-2, specifically designed to generate DNA embeddings that naturally cluster and segregate genomes of different species in the embedding space. This can greatly benefit a wide range of genome applications, including species classification and identification, metagenomics binning, and understanding evolutionary relationships.
Results on species clustering.
The pre-trained model is available on Hugging Face as zhihan1996/DNABERT-S.
To download the model from the command line:
gdown 1ejNOMXdycorDzphLT6jnfGIPUxi6fO0g # pip install gdown
unzip DNABERT-S.zip # unzip the model
The training data of DNABERT-S can be downloaded with:
gdown 1p59ch_MO-9DXh3LUIvorllPJGLEAwsUp # pip install gdown
unzip dnabert-s_train.zip # unzip the data
The evaluation data can be downloaded with:
gdown 1I44T2alXrtXPZrhkuca6QP3tFHxDW98c # pip install gdown
unzip dnabert-s_eval.zip # unzip the data
conda create -n DNABERT_S python=3.9
conda activate DNABERT_S
pip install -r requirements.txt
pip uninstall triton # triton can cause errors on GPUs other than A100
Our model is easy to use with the transformers package.
To load the model from Hugging Face:
import torch
from transformers import AutoTokenizer, AutoModel
tokenizer = AutoTokenizer.from_pretrained("zhihan1996/DNABERT-S", trust_remote_code=True)
model = AutoModel.from_pretrained("zhihan1996/DNABERT-S", trust_remote_code=True)
To calculate the embedding of a DNA sequence:
dna = "ACGTAGCATCGGATCTATCTATCGACACTTGGTTATCGATCTACGAGCATCTCGTTAGC"
inputs = tokenizer(dna, return_tensors = 'pt')["input_ids"]
hidden_states = model(inputs)[0] # [1, sequence_length, 768]
# embedding with mean pooling
embedding_mean = torch.mean(hidden_states[0], dim=0)
print(embedding_mean.shape) # expect to be 768
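Once two sequences have been embedded, a common next step is to compare their embeddings. The sketch below is not part of this repo. It computes cosine similarity on plain Python lists, assuming each embedding has been converted with something like `embedding_mean.detach().tolist()`. The choice of cosine similarity, the function name `cosine_similarity`, and the toy vectors are all illustrative assumptions.

```python
import math
from typing import Sequence


def cosine_similarity(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine similarity between two equal-length embedding vectors."""
    if len(a) != len(b):
        raise ValueError("embeddings must have the same dimension")
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        raise ValueError("cosine similarity is undefined for a zero vector")
    return dot / (norm_a * norm_b)


# Toy 3-dimensional vectors stand in for real 768-dimensional embeddings.
emb_a = [1.0, 0.0, 2.0]
emb_b = [2.0, 0.0, 4.0]  # same direction as emb_a
emb_c = [0.0, 3.0, 0.0]  # orthogonal to emb_a

sim_ab = cosine_similarity(emb_a, emb_b)
sim_ac = cosine_similarity(emb_a, emb_c)
print(sim_ab, sim_ac)
```

Values close to 1 indicate embeddings pointing in nearly the same direction, and values close to 0 indicate nearly orthogonal embeddings.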
Our code base expects pairs of DNA sequences for pre-training. The training data should be a CSV file with two columns and no header. Each row contains one pair of DNA sequences for which you want the model to generate similar embeddings. See data/debug_train.csv for an example.
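As a rough illustration of the format described above, the following sketch writes sequence pairs to a two-column CSV file with no header row and reads them back. The helper names `write_pairs` and `read_pairs`, the file name `toy_train.csv`, and the toy sequences are hypothetical and are not part of this repo.

```python
import csv
import os
import tempfile
from typing import Iterable, List, Tuple


def write_pairs(path: str, pairs: Iterable[Tuple[str, str]]) -> int:
    """Write DNA sequence pairs as a two-column CSV with no header row."""
    count = 0
    with open(path, "w", newline="") as handle:
        writer = csv.writer(handle)
        for seq_a, seq_b in pairs:
            writer.writerow([seq_a, seq_b])
            count += 1
    return count


def read_pairs(path: str) -> List[Tuple[str, str]]:
    """Read the pairs back, checking that every row has exactly two columns."""
    pairs = []
    with open(path, newline="") as handle:
        for row_number, row in enumerate(csv.reader(handle), start=1):
            if len(row) != 2:
                raise ValueError(f"row {row_number} has {len(row)} columns, expected 2")
            pairs.append((row[0], row[1]))
    return pairs


# Toy pairs; in real data each pair holds two sequences that should embed similarly.
toy_pairs = [("ACGTACGT", "TTGACCGA"), ("GGCATCGA", "CGATTGCA")]

with tempfile.TemporaryDirectory() as tmp_dir:
    csv_path = os.path.join(tmp_dir, "toy_train.csv")
    n_written = write_pairs(csv_path, toy_pairs)
    loaded = read_pairs(csv_path)

print(n_written, loaded)
```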
Important arguments:
- resdir: directory to save model parameters
- datapath: directory of the data
- train_dataname: the name of the training data file (e.g., "a.csv")
- val_dataname: the name of the validation data file (e.g., "a.csv")
- max_length: set it to 0.2 * DNA_length (e.g., 200 for 1000-bp DNA)
- train_batch_size: batch size for training data; change it to fit your GPU RAM
- con_method: contrastive learning method, one of "same_species", "dropout", "double_strand", "mutate"
- mix: whether to use the i-Mix method
- mix_layer_num: the layer at which to perform i-Mix; a value of -1 means manifold i-Mix
- curriculum: whether to use curriculum learning
- Other arguments can also be adjusted.
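The max_length rule of thumb above (0.2 * DNA_length) can be written as a small helper. This is only a sketch of the arithmetic: the function names and the default ratio argument are illustrative, and the repo itself does not provide such a helper.

```python
from typing import Iterable, Tuple


def suggest_max_length(dna_length_bp: int, ratio: float = 0.2) -> int:
    """Apply the rule of thumb max_length = ratio * DNA_length (in base pairs)."""
    if dna_length_bp <= 0:
        raise ValueError("dna_length_bp must be positive")
    # Round to the nearest integer and never return less than one token.
    return max(1, round(dna_length_bp * ratio))


def suggest_max_length_for_pairs(pairs: Iterable[Tuple[str, str]]) -> int:
    """Base the suggestion on the longest sequence found in any pair."""
    longest = max(max(len(seq_a), len(seq_b)) for seq_a, seq_b in pairs)
    return suggest_max_length(longest)


length_1k = suggest_max_length(1000)    # the README's example: 1000-bp DNA
length_10k = suggest_max_length(10000)
print(length_1k, length_10k)
```

The result can then be passed as the value of `--max_length`.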
For our curriculum contrastive learning method, you can use:
cd pretrain
export PATH_TO_DATA_DICT=/path/to/data
export TRAIN_FILE=debug_train.csv # use this for debugging; for real training, use train_2m.csv
python main.py \
--resdir ./results/ \
--datapath ${PATH_TO_DATA_DICT} \
--train_dataname ${TRAIN_FILE} \
--val_dataname val_48k.csv \
--seed 1 \
--logging_step 10000 \
--logging_num 12 \
--max_length 2000 \
--train_batch_size 48 \
--val_batch_size 360 \
--lr 3e-06 \
--lr_scale 100 \
--epochs 3 \
--feat_dim 128 \
--temperature 0.05 \
--con_method same_species \
--mix \
--mix_alpha 1.0 \
--mix_layer_num -1 \
--curriculum
This training script expects 8 A100 80GB GPUs. If you are using other types of devices, please change train_batch_size and max_length accordingly.
After model training, you will find the trained model at ./pretrain/results/$file_name
The file_name is automatically set based on the hyperparameters, and the code regularly saves checkpoints.
It should be something like ./results/contrastive.HardNeg.epoch2.debug_train.csv.lr3e-06.lrscale100.bs48.maxlength200.tmp0.05.decay1.seed1.turn1/100
The best model according to validation is saved in ./pretrain/results/$file_name/best/
Scripts for other experiments are all in ./pretrain/results
cd evaluate
gdown 1ejNOMXdycorDzphLT6jnfGIPUxi6fO0g
unzip DNABERT-S.zip
export MODEL_DIR=/path/to/DNABERT-S # e.g., /root/Downloads/DNABERT-S
Copy the necessary files to the folder where the model is saved. This works around a bug in the Hugging Face Transformers package: model code files such as bert_layer.py are sometimes not saved to the model directory together with the model weights, so we copy them manually.
export MODEL_DIR=/path/to/the/trained/model # (e.g., /root/ICML2024/train/pretrain/results/epoch3.debug_train.csv.lr3e-06.lrscale100.bs24.maxlength2000.tmp0.05.seed1.con_methodsame_species.mixTrue.mix_layer_num-1.curriculumTrue/0)
cp model_codes/* ${MODEL_DIR}
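To confirm that the copy step above left nothing behind, one option is to list the files in model_codes that are still absent from the model directory. The sketch below is a hypothetical check, not part of this repo. The helper name `missing_code_files` and the file `extra_module.py` are made up for the demonstration; only `bert_layer.py` is named in this README.

```python
import tempfile
from pathlib import Path
from typing import List


def missing_code_files(code_dir: str, model_dir: str) -> List[str]:
    """Return names of files in code_dir that are not present in model_dir."""
    code_names = sorted(p.name for p in Path(code_dir).iterdir() if p.is_file())
    return [name for name in code_names if not (Path(model_dir) / name).is_file()]


# Demonstration on temporary directories with placeholder files.
with tempfile.TemporaryDirectory() as tmp:
    code_dir = Path(tmp) / "model_codes"
    model_dir = Path(tmp) / "trained_model"
    code_dir.mkdir()
    model_dir.mkdir()
    (code_dir / "bert_layer.py").write_text("# placeholder\n")
    (code_dir / "extra_module.py").write_text("# placeholder\n")   # hypothetical file
    (model_dir / "extra_module.py").write_text("# placeholder\n")  # already copied
    missing = missing_code_files(str(code_dir), str(model_dir))

print(missing)
```

An empty list means every file from model_codes is already present in the model directory.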
export DATA_DIR=/path/to/the/unzipped/folders
# evaluate the trained model
python eval_clustering_classification.py --test_model_dir ${MODEL_DIR} --data_dir ${DATA_DIR} --model_list "test"
# evaluate baselines (e.g., TNF and DNABERT-2)
python eval_clustering_classification.py --data_dir ${DATA_DIR} --model_list "tnf, dnabert2"
export DATA_DIR=/path/to/the/unzipped/folders
export MODEL_DIR=/path/to/the/trained/model
# evaluate the trained model
python eval_binning.py --test_model_dir ${MODEL_DIR} --data_dir ${DATA_DIR} --model_list "test"
# evaluate baselines (e.g., TNF and DNABERT-2)
python eval_binning.py --data_dir ${DATA_DIR} --model_list "tnf, dnabert2"
If you have any questions regarding our paper or code, please feel free to open an issue or email Zhihan Zhou ([email protected]).
If you use DNABERT-S in your work, please consider citing our papers:
DNABERT-S
@misc{zhou2024dnaberts,
title={DNABERT-S: Learning Species-Aware DNA Embedding with Genome Foundation Models},
author={Zhihan Zhou and Weimin Wu and Harrison Ho and Jiayi Wang and Lizhen Shi and Ramana V Davuluri and Zhong Wang and Han Liu},
year={2024},
eprint={2402.08777},
archivePrefix={arXiv},
primaryClass={q-bio.GN}
}
DNABERT-2
@misc{zhou2023dnabert2,
title={DNABERT-2: Efficient Foundation Model and Benchmark For Multi-Species Genome},
author={Zhihan Zhou and Yanrong Ji and Weijian Li and Pratik Dutta and Ramana Davuluri and Han Liu},
year={2023},
eprint={2306.15006},
archivePrefix={arXiv},
primaryClass={q-bio.GN}
}
DNABERT
@article{ji2021dnabert,
author = {Ji, Yanrong and Zhou, Zhihan and Liu, Han and Davuluri, Ramana V},
title = "{DNABERT: pre-trained Bidirectional Encoder Representations from Transformers model for DNA-language in genome}",
journal = {Bioinformatics},
volume = {37},
number = {15},
pages = {2112-2120},
year = {2021},
month = {02},
issn = {1367-4803},
doi = {10.1093/bioinformatics/btab083},
url = {https://doi.org/10.1093/bioinformatics/btab083},
eprint = {https://academic.oup.com/bioinformatics/article-pdf/37/15/2112/50578892/btab083.pdf},
}