It's possible to train a RAPPPID model using the `train.py` utility.
You'll first need a dataset in a format that RAPPPID understands. There are two options:

- If you wish to use the datasets from the Szymborski & Emad manuscript, read the "Szymborski & Emad Datasets" heading in the data.md docs.
- To prepare a new dataset, read the "Preparing RAPPPID Datasets" heading in data.md.
First, begin by generating a SentencePiece vocabulary using `rapppid/train_seg.py`. You can run this script from the CLI:
Usage: train_seg.py [OPTIONS] SEQ_PATH TRAIN_PATH
Some more details:
`SEQ_PATH`
: The location of the sequences Pickle file (see data.md).

`TRAIN_PATH`
: The location of the training pairs Pickle file (see data.md).

`seed: int`
: Random seed for determinism.

`vocab_size: int`
: The size of the vocabulary to be generated.
    - recommended value: a value of `250` was used in the paper.
This script makes sure that the SentencePiece model is only trained on sequences present in the training dataset to ensure no data leakage.
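Once the vocabulary is trained, you can sanity-check the generated SentencePiece model by loading it and tokenizing a protein sequence. Below is a minimal sketch using the `sentencepiece` Python library; the model filename `spm.model` and the example sequence are placeholders, since the actual output path depends on how you ran `train_seg.py`.

```python
import sentencepiece as spm

# Placeholder path: point this at the model file written by train_seg.py.
sp = spm.SentencePieceProcessor(model_file="spm.model")

# A made-up amino-acid sequence, just for illustration.
seq = "MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ"

pieces = sp.encode(seq, out_type=str)  # sub-word tokens
ids = sp.encode(seq, out_type=int)     # token ids fed to the embedding layer

print(pieces)
print(ids)
print("vocabulary size:", sp.get_piece_size())  # should match vocab_size (e.g. 250)
```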
To train, validate, and test the model, run the `train.py` Python file in the `rapppid` folder. `train.py` takes the following positional arguments:
`batch_size: int`
: The training mini-batch size.
    - recommended value: a value of `80` was used in the paper on an RTX 2080, with 32 CPU cores clocked at 2.2GHz.
`train_path: Path`
: The path to the training files. RAPPPID training files can be found in the `data/rapppid` folder.

`val_path: Path`
: The path to the validation files. RAPPPID validation files can be found in the `data/rapppid` folder.

`test_path: Path`
: The path to the testing files. RAPPPID testing files can be found in the `data/rapppid` folder.

`seqs_path: Path`
: The path to the file containing protein sequences. RAPPPID protein sequences can be found in the `data/rapppid` folder.

`trunc_len: int`
: Sequences longer than `trunc_len` will be truncated to this length.
    - recommended value: a value of `1500` was used in the paper, but values as large as `3000` and as small as `1000` have been used during development. A value of `3000` means almost all proteins won't be truncated, while `1500` still only truncates a small proportion of proteins. Larger values lead to vanishing gradients, so if training is unstable, this is a very good parameter to look at.
`embedding_size: int`
: The size of the token embeddings to use. This also dictates the number of parameters in the LSTM cells.
    - recommended value: a value of `64` was used in the paper. `32` has also worked well.
`num_epochs: int`
: The maximum number of epochs to run. Testing will be run on the epoch with the lowest validation loss.
    - recommended value: `train.py` will update the model checkpoints when the validation loss reaches a new low. So in the paper, we set the number of epochs to `100` and reported the test metrics of the model with the lowest val loss (this is done automatically by `train.py`).
`lstm_dropout_rate: float`
: The rate at which connections are dropped in the LSTM layers (aka DropConnect).
    - recommended value: See hyperparams.md. We tuned this hyperparameter and recommend you do so on new datasets as well.
`classhead_dropout_rate: float`
: The rate at which activations are dropped at the fully-connected classifier (aka Dropout).
    - recommended value: See hyperparams.md. We tuned this hyperparameter and recommend you do so on new datasets as well.
`rnn_num_layers: int`
: Number of LSTM layers to use.
    - recommended value: See hyperparams.md. We tuned this hyperparameter and recommend you do so on new datasets as well.
`class_head_name: str`
: The kind of classifier head to use.
    - recommended value: Use `concat` to replicate the RAPPPID manuscript.
    - Update: We've found that using `mult` provides similar performance, reduces the number of parameters, and is more deterministic.
`variational_dropout: bool`
: Whether the DropConnect applied to the LSTM layers is done using variational dropout or not.
    - recommended value: `False`.
`lr_scaling: bool`
: Whether or not to scale the learning rate with sequence length.
    - recommended value: Set to `False` to replicate the RAPPPID manuscript; other values are poorly supported.
`model_file: str`
: Path to the SentencePiece model file generated in Step 2.

`log_path: Path`
: Where to store logging files (saved weights, tensorboard files, hyper-parameters).
    - n.b.: the directory in `log_path` must already contain the following directories (a sketch for creating them appears after the argument list below):
        - `args`: The `args` folder will hold (in JSON files) all the hyperparameters, as well as the training, validation, and testing metrics. The most useful information is usually here.
        - `chkpts`: PyTorch Lightning model checkpoints are stored here. They hold both hyperparameters as well as model weights.
        - `tb_logs`: Holds the tensorboard logs.
        - `onnx`: ONNX files are meant to be saved here, but serialization usually fails, so it's best to use the weights from `chkpts`.
        - `charts`: Quick ROC and PR charts for the testing dataset are generated for each trained model.
`vocab_size: int`
: The size of the SentencePiece vocabulary. Use the value set in Step 2.

`embedding_droprate: float`
: The rate at which embeddings are dropped (aka Embedding Dropout).
    - recommended value: See hyperparams.md. We tuned this hyperparameter and recommend you do so on new datasets as well.
`transfer_path: str`
: If you wish to load weights that were pre-trained, include the checkpoint file from `/logs/chkpts/yourmodelname.chkpt`.
`optimizer_type: str`
: The optimizer to use. Valid values are `ranger21` and `adam`.
    - recommended value: We use `ranger21` in the manuscript, but `adam` also works well.
`swa: bool`
: Enable Stochastic Weight Averaging.
    - recommended value: `True` in the manuscript.
`seed: int`
: Seed to use for training.
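Since `train.py` expects the directories listed under `log_path` to already exist, it can be convenient to create them up front. Here is a minimal sketch; the `logs` location is just an example, use whatever path you pass as `log_path`.

```python
from pathlib import Path

# Example log_path; replace with the path you pass to train.py.
log_path = Path("logs")

# Subdirectories train.py expects under log_path (see the log_path argument above).
for subdir in ("args", "chkpts", "tb_logs", "onnx", "charts"):
    (log_path / subdir).mkdir(parents=True, exist_ok=True)
```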
The path in `log_path` contains model checkpoints, as well as logs and evaluation metrics. You can also monitor loss and various metrics live on tensorboard.
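If you need the raw weights from a checkpoint in `chkpts` (for example, because ONNX serialization failed), you can inspect it directly with PyTorch. This is a generic sketch rather than RAPPPID-specific code: the checkpoint filename is a placeholder, and the `hyper_parameters` entry is only present when hyperparameters were saved in the usual PyTorch Lightning way.

```python
import torch

# Placeholder path: checkpoints are written under <log_path>/chkpts/.
ckpt_path = "logs/chkpts/yourmodelname.chkpt"

# Lightning checkpoints are ordinary torch pickles; load onto the CPU to inspect.
# On newer PyTorch versions you may need to pass weights_only=False.
ckpt = torch.load(ckpt_path, map_location="cpu")

state_dict = ckpt["state_dict"]              # model weights
print(list(state_dict)[:5])                  # first few parameter names
print(ckpt.get("hyper_parameters", {}))      # hyperparameters, if saved
```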
Training, validation, and testing metrics, as well as hyperparameters, are all present in the `args` folder in JSON format. Simply look for the file with your model name (a Unix timestamp plus two random words).
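To read those values programmatically, you can load the JSON directly. A short sketch, assuming the layout described above; the model name below is a placeholder, and whether hyperparameters and metrics end up in one JSON file or several (and under which keys) depends on the RAPPPID version, so inspect what is there first.

```python
import json
from pathlib import Path

# Placeholder model name: a Unix timestamp plus two random words.
args_dir = Path("logs/args")

# Inspect every JSON file recorded for this model name.
for args_file in sorted(args_dir.glob("1631456789_quiet_heron*.json")):
    with args_file.open() as f:
        run_info = json.load(f)
    # Keys depend on the RAPPPID version; print them to see what was logged.
    print(args_file.name, "->", sorted(run_info))
```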