Preparing the datasets

To download the wwPDB dataset and the preprocessed training data, you need at least 1 TB of disk space.
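Before starting the download, it can help to confirm that the target volume has enough free space. The sketch below uses `shutil.disk_usage` from the Python standard library. The helper name `has_free_space` is hypothetical, and reading the stated disk requirement as 10**12 bytes is an assumption made for illustration.

```python
import shutil

# Assumption: the stated requirement is read as 10**12 bytes; use 2**40 if you prefer TiB.
REQUIRED_BYTES = 10**12


def has_free_space(path: str, required_bytes: int = REQUIRED_BYTES) -> bool:
    """Return True if the filesystem containing `path` has enough free space."""
    # disk_usage reports total, used, and free bytes for the filesystem.
    return shutil.disk_usage(path).free >= required_bytes


if __name__ == "__main__":
    # Point this at an existing directory on the target volume.
    target = "."
    print(f"Enough space for the datasets: {has_free_space(target)}")
```

The path passed in must already exist, so check the parent volume if the data directory has not been created yet.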

Use the following command to download the preprocessed wwPDB training databases:

wget -P /af3-dev/release_data/ https://af3-dev.tos-cn-beijing.volces.com/release_data.tar.gz
tar -xzvf /af3-dev/release_data/release_data.tar.gz -C /af3-dev/release_data/
rm /af3-dev/release_data/release_data.tar.gz

The data should be placed in the /af3-dev/release_data/ directory. You can also download it to a different directory, but remember to update DATA_ROOT_DIR in configs/configs_data.py accordingly. The data hierarchy after extraction is as follows:

├── components.v20240608.cif [408M] # ccd source file
├── components.v20240608.cif.rdkit_mol.pkl [121M] # rdkit Mol object generated by ccd source file
├── indices [33M] # chain or interface entries
├── mmcif [283G]  # raw mmcif data
├── mmcif_bioassembly [36G] # preprocessed wwPDB structural data
├── mmcif_msa [450G] # msa files
├── posebusters_bioassembly [42M] # preprocessed posebusters structural data
├── posebusters_mmcif [361M] # raw mmcif data
├── recentPDB_bioassembly [1.5G] # preprocessed recentPDB structural data
└── seq_to_pdb_index.json [45M] # sequence to pdb id mapping file
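After extraction, a quick check that every top-level entry is present can catch an interrupted download early. The following is a minimal sketch: the entry names are copied from the listing above, while `missing_entries` is a hypothetical helper and the code assumes the entries sit directly under the data root.

```python
from pathlib import Path

# Top-level entries copied from the hierarchy listed above.
EXPECTED_ENTRIES = [
    "components.v20240608.cif",
    "components.v20240608.cif.rdkit_mol.pkl",
    "indices",
    "mmcif",
    "mmcif_bioassembly",
    "mmcif_msa",
    "posebusters_bioassembly",
    "posebusters_mmcif",
    "recentPDB_bioassembly",
    "seq_to_pdb_index.json",
]


def missing_entries(data_root: str) -> list[str]:
    """Return the expected top-level entries that are absent under data_root."""
    root = Path(data_root)
    return [name for name in EXPECTED_ENTRIES if not (root / name).exists()]


if __name__ == "__main__":
    # Default location used in this guide; change it if DATA_ROOT_DIR differs.
    missing = missing_entries("/af3-dev/release_data/")
    print("All expected entries found." if not missing else f"Missing: {missing}")
```

This only checks that the entries exist; it does not verify their sizes or contents.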

Data processing scripts have also been released. Refer to prepare_training_data.md for generating {dataset}_bioassembly and indices, and to msa_pipeline.md for the pipelines that produce mmcif_msa and seq_to_pdb_index.json.

Training demo

After the installation and data preparations, you can run the following command to train the model from scratch:

bash train_demo.sh 

Key arguments in this script are explained as follows:

  • dtype: data type used in training. Valid options include "bf16" and "fp32".

    • --dtype fp32: the model will be trained in full FP32 precision.
    • --dtype bf16: the model will be trained in BF16 mixed precision. By default, the SampleDiffusion, ConfidenceHead, Mini-rollout, and Loss parts still run in FP32 precision. To train and run inference in full BF16 mixed precision, pass the following arguments to train_demo.sh:
      --skip_amp.sample_diffusion_training false \
      --skip_amp.confidence_head false \
      --skip_amp.sample_diffusion false \
      --skip_amp.loss false \
  • ema_decay: the decay rate of the EMA, default is 0.999.

  • sample_diffusion.N_step: during evaluation, the number of steps for the diffusion process is reduced to 20 to improve efficiency.

  • data.train_sets/data.test_sets: the datasets used for training and evaluation. If there are multiple datasets, separate them with commas.

  • Some settings follow those in the AlphaFold 3 paper. The table in model_performance.md shows the training settings and memory usage for the different training stages.

  • In this version, we do not use the template and RNA MSA features for training. This corresponds to the default settings in configs/configs_base.py and configs/configs_data.py:

    --model.template_embedder.n_blocks 0 \
    --data.msa.enable_rna_msa false \

    These features will be considered in future work.

  • The model also supports distributed training with PyTorch’s torchrun. For example, if you’re running distributed training on a single node with 4 GPUs, you can use:

    torchrun --nproc_per_node=4 runner/train.py

    You can also pass any other arguments in the form --<ARGS_KEY> <ARGS_VALUE>.
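The dotted keys used above, such as `--skip_amp.loss` or `--data.msa.enable_rna_msa`, suggest a nested configuration addressed by path. The following is a minimal sketch of how such an override could be applied to a nested dictionary. `apply_override` and the example config are hypothetical; this does not reproduce the repository's actual argument parser.

```python
def apply_override(config: dict, dotted_key: str, value) -> dict:
    """Set config[a][b][c] = value for dotted_key 'a.b.c', creating levels as needed."""
    *parents, leaf = dotted_key.split(".")
    node = config
    for name in parents:
        # Walk down one level, creating an empty dict if the level is missing.
        node = node.setdefault(name, {})
    node[leaf] = value
    return config


if __name__ == "__main__":
    # Hypothetical config mirroring two of the keys mentioned in this guide.
    cfg = {"data": {"msa": {"enable_rna_msa": True}}, "dtype": "fp32"}
    apply_override(cfg, "data.msa.enable_rna_msa", False)
    apply_override(cfg, "dtype", "bf16")
    print(cfg)
```

A real parser would also convert string values such as `false` or `20` to the type of the existing default, which this sketch leaves out.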

If you want to speed up training, see the documentation on setting up kernels.
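As background for the `ema_decay` argument listed above: an exponential moving average of the weights is commonly maintained as `ema = decay * ema + (1 - decay) * weight` after each optimizer step. The sketch below applies this textbook rule to plain Python floats. That the training code follows this exact form is an assumption, and `ema_update` is a hypothetical helper.

```python
def ema_update(ema_weights, weights, decay=0.999):
    """Return new EMA values using ema = decay * ema + (1 - decay) * weight."""
    return [decay * e + (1.0 - decay) * w for e, w in zip(ema_weights, weights)]


if __name__ == "__main__":
    ema = [0.0, 1.0]
    current = [1.0, 1.0]
    for _ in range(3):
        ema = ema_update(ema, current)
    # The first entry moves slowly toward 1.0; the second stays close to 1.0.
    print(ema)
```

Under this rule, a decay of 0.999 averages over a horizon of roughly 1 / (1 - 0.999) = 1000 steps, so the EMA weights lag the raw weights early in training.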

Finetune demo

If you want to fine-tune the model on a specific subset, such as an antibody dataset, you only need to provide a PDB list file and load the pretrained weights as finetune_demo.sh shows:

# wget -P /af3-dev/release_model/ https://af3-dev.tos-cn-beijing.volces.com/release_model/model_v0.2.0.pt
checkpoint_path="/af3-dev/release_model/model_v0.2.0.pt"
...

--load_checkpoint_path ${checkpoint_path} \
--load_checkpoint_ema_path ${checkpoint_path} \
--data.weightedPDB_before2109_wopb_nometalc_0925.base_info.pdb_list examples/subset.txt \

Here, subset.txt is a file containing PDB IDs, one per line, such as:

6hvq
5mqc
5zin
3ew0
5akv
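Since the fine-tuning subset is defined entirely by this list file, it may be worth validating the file before launching a run. The sketch below reads a list in the format shown above. `read_pdb_list` is a hypothetical helper, and both the ID pattern (a digit followed by three alphanumeric characters) and the lowercasing are assumptions based on the example IDs, not requirements confirmed by the repository.

```python
import re
from pathlib import Path

# Assumption: classic 4-character PDB IDs (a digit followed by three alphanumerics).
PDB_ID_PATTERN = re.compile(r"^[0-9][a-z0-9]{3}$")


def read_pdb_list(path: str) -> list[str]:
    """Read PDB IDs (one per line), lowercase them, and drop blanks and duplicates."""
    ids = []
    for line in Path(path).read_text().splitlines():
        pdb_id = line.strip().lower()
        if not pdb_id:
            continue  # skip blank lines
        if not PDB_ID_PATTERN.match(pdb_id):
            raise ValueError(f"Unexpected PDB ID format: {pdb_id!r}")
        if pdb_id not in ids:
            ids.append(pdb_id)
    return ids
```

For example, `read_pdb_list("examples/subset.txt")` would return the IDs in file order, or raise `ValueError` on the first malformed entry.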