Expose additional Chai-1 parameters in the pipeline #11

Merged
merged 9 commits into from
Nov 28, 2024
Changes from 7 commits
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -23,3 +23,4 @@ Special thanks to the following for their contributions to the release:
- [PR #5](https://github.com/seqeralabs/nf-chai/pull/5) - Add chai-1 functionality to the pipeline
- [PR #8](https://github.com/seqeralabs/nf-chai/pull/8) - Add `--weights_dir` parameter to provide pre-downloaded weights to Chai-1
- [PR #9](https://github.com/seqeralabs/nf-chai/pull/9) - Add reports file for pipeline outputs
- [PR #11](https://github.com/seqeralabs/nf-chai/pull/11) - Add additional Chai-1 parameters to the Nextflow config
33 changes: 29 additions & 4 deletions bin/run_chai_1.py
@@ -22,6 +22,31 @@ def main():
type=Path,
help="Path to the input FASTA file."
)
# Add optional arguments with current defaults
parser.add_argument(
"--num-trunk-recycles",
type=int,
default=3,
help="Number of trunk recycles (default: 3)"
)
parser.add_argument(
"--num-diffn-timesteps",
type=int,
default=200,
help="Number of diffusion timesteps (default: 200)"
)
parser.add_argument(
"--seed",
type=int,
default=42,
help="Random seed for reproducibility (default: 42)"
)
parser.add_argument(
"--use-esm-embeddings",
action="store_true",
default=True,
help="Use ESM embeddings (enabled by default)"
)

# Parse arguments
args = parser.parse_args()
@@ -40,11 +65,11 @@ def main():
run_inference(
fasta_file=args.fasta_file,
output_dir=args.output_dir,
num_trunk_recycles=3,
num_diffn_timesteps=200,
seed=42,
num_trunk_recycles=args.num_trunk_recycles,
num_diffn_timesteps=args.num_diffn_timesteps,
seed=args.seed,
device=device,
use_esm_embeddings=True,
use_esm_embeddings=args.use_esm_embeddings,
)

if __name__ == "__main__":
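
One detail of the `--use-esm-embeddings` argument above is worth noting: with `action="store_true"` and `default=True`, the parsed value is `True` whether or not the flag is passed, so the option cannot be switched off from the command line. The following is a minimal sketch of one possible alternative, assuming Python 3.9 or later for `argparse.BooleanOptionalAction`. The `build_parser` helper is hypothetical and not part of this PR.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    """Hypothetical helper mirroring the optional arguments added in this PR."""
    parser = argparse.ArgumentParser(description="Sketch of the Chai-1 wrapper options")
    parser.add_argument("--num-trunk-recycles", type=int, default=3)
    parser.add_argument("--num-diffn-timesteps", type=int, default=200)
    parser.add_argument("--seed", type=int, default=42)
    # BooleanOptionalAction (Python 3.9+) registers both --use-esm-embeddings
    # and --no-use-esm-embeddings, so the default of True can be overridden.
    parser.add_argument(
        "--use-esm-embeddings",
        action=argparse.BooleanOptionalAction,
        default=True,
    )
    return parser


parser = build_parser()
default_args = parser.parse_args([])
disabled_args = parser.parse_args(["--no-use-esm-embeddings", "--seed", "7"])

print(default_args.use_esm_embeddings)   # True
print(disabled_args.use_esm_embeddings)  # False
print(disabled_args.seed)                # 7
```

With this variant, a caller that wants to disable the embeddings would have to pass `--no-use-esm-embeddings` explicitly rather than just omitting the flag.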
19 changes: 13 additions & 6 deletions modules/local/chai_1/main.nf
@@ -7,25 +7,32 @@ process CHAI_1 {
input:
tuple val(meta), path(fasta)
path weights_dir
val num_trunk_recycles
val num_diffn_timesteps
val seed
val use_esm_embeddings

output:
tuple val(meta), path("${meta.id}/*.cif"), emit: structures
tuple val(meta), path("${meta.id}/*.npz"), emit: arrays
path "versions.yml" , emit: versions

script:
def downloads_dir = weights_dir ?: './downloads'
def args = task.ext.args ?: ''
def esm_flag = use_esm_embeddings ? '--use-esm-embeddings' : ''
"""
CHAI_DOWNLOADS_DIR=$downloads_dir \\
run_chai_1.py \\
--output-dir ${meta.id} \\
--fasta-file ${fasta}
--fasta-file ${fasta} \\
--output-dir . \\
--num-trunk-recycles ${num_trunk_recycles} \\
--num-diffn-timesteps ${num_diffn_timesteps} \\
--seed ${seed} \\
${esm_flag} \\
$args

cat <<-END_VERSIONS > versions.yml
"${task.process}":
python: \$(python --version | sed 's/Python //g')
Review comment (Member): Why did you remove this, @FloWuenne?

Reply (Contributor Author): Don't remember removing this, must have been accidental, sorry!

chai_lab: \$(python -c "import chai_lab; print(chai_lab.__version__)")
torch: \$(python -c "import torch; print(torch.__version__)")
Review comment (Member): Why did you remove this, @FloWuenne?

END_VERSIONS
"""

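
To make the mapping from process inputs to command-line flags in the `CHAI_1` script block easier to follow, here is a rough Python sketch of the same interpolation. The `build_chai_command` function and the example values (`input.fasta`, `.`) are hypothetical and only illustrate the pattern; the real command is assembled by Nextflow from the script template.

```python
import shlex


def build_chai_command(fasta, output_dir, num_trunk_recycles, num_diffn_timesteps,
                       seed, use_esm_embeddings, extra_args=""):
    """Hypothetical mirror of the CHAI_1 script block's string interpolation."""
    command = [
        "run_chai_1.py",
        "--fasta-file", str(fasta),
        "--output-dir", str(output_dir),
        "--num-trunk-recycles", str(num_trunk_recycles),
        "--num-diffn-timesteps", str(num_diffn_timesteps),
        "--seed", str(seed),
    ]
    # Same idea as `esm_flag`: the flag is emitted only when the value is true.
    if use_esm_embeddings:
        command.append("--use-esm-embeddings")
    # Same idea as `$args`: free-form extra arguments are appended last.
    command.extend(shlex.split(extra_args))
    return command


with_esm = build_chai_command("input.fasta", ".", 3, 200, 42, True)
without_esm = build_chai_command("input.fasta", ".", 3, 200, 42, False)

print(shlex.join(with_esm))
print(shlex.join(without_esm))
```

When `use_esm_embeddings` is false the flag is simply omitted. Since `bin/run_chai_1.py` declares the flag with `store_true` and `default=True`, omitting it still leaves the value at `True` on the Python side.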
5 changes: 5 additions & 0 deletions nextflow.config
@@ -13,6 +13,10 @@ params {
input = null
weights_dir = null
use_gpus = false
num_trunk_recycles = 3
num_diffn_timesteps = 200
seed = 42
use_esm_embeddings = true

// Boilerplate options
outdir = null
@@ -26,6 +30,7 @@ params {

// Schema validation default options
validate_params = true

}

// Default publishing settings for all processes
31 changes: 30 additions & 1 deletion nextflow_schema.json
@@ -26,7 +26,8 @@
"type": "string",
"format": "directory-path",
"description": "The output directory where the results will be saved. You have to use absolute paths to storage on Cloud infrastructure.",
"fa_icon": "fas fa-folder-open"
"fa_icon": "fas fa-folder-open",
"help_text": ""
},
"weights_dir": {
"type": "string",
@@ -39,6 +40,34 @@
"type": "boolean",
"description": "Run compatible tasks on GPUs rather than CPUs (default).",
"fa_icon": "fas fa-microchip"
},
"num_trunk_recycles": {
"type": "integer",
"default": 3,
"fa_icon": "fas fa-recycle",
"hidden": true
},
"num_diffn_timesteps": {
"type": "integer",
"default": 200,
"fa_icon": "fas fa-shoe-prints",
"hidden": true,
"description": "Number of diffusion steps to use."
},
"seed": {
"type": "integer",
"default": 42,
"fa_icon": "fas fa-seedling",
"hidden": true,
"description": "Random seed to be used for Chai-1 calculations",
"help_text": ""
},
"use_esm_embeddings": {
"type": "boolean",
"default": true,
"fa_icon": "fas fa-stamp",
"hidden": true,
"description": "Use user-provided esm model embeddings?"
}
}
},
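
The new schema entries are meant to stay in step with the defaults added to `nextflow.config`. As a rough illustration of that relationship, the sketch below compares the two using only the Python standard library. The dictionaries are transcribed from the hunks in this PR (with `fa_icon` omitted), and the helper functions are hypothetical; the pipeline itself relies on its own schema validation rather than anything like this.

```python
# Schema fragment transcribed from the nextflow_schema.json hunk (fa_icon omitted).
SCHEMA_PROPERTIES = {
    "num_trunk_recycles": {"type": "integer", "default": 3, "hidden": True},
    "num_diffn_timesteps": {
        "type": "integer", "default": 200, "hidden": True,
        "description": "Number of diffusion steps to use.",
    },
    "seed": {
        "type": "integer", "default": 42, "hidden": True,
        "description": "Random seed to be used for Chai-1 calculations",
    },
    "use_esm_embeddings": {
        "type": "boolean", "default": True, "hidden": True,
        "description": "Use user-provided esm model embeddings?",
    },
}

# Defaults transcribed from the nextflow.config hunk.
CONFIG_DEFAULTS = {
    "num_trunk_recycles": 3,
    "num_diffn_timesteps": 200,
    "seed": 42,
    "use_esm_embeddings": True,
}


def matches_type(value, schema_type):
    """Check a Python value against the two JSON Schema types used here."""
    if schema_type == "integer":
        # bool is a subclass of int in Python, so exclude it explicitly.
        return isinstance(value, int) and not isinstance(value, bool)
    if schema_type == "boolean":
        return isinstance(value, bool)
    raise ValueError(f"unsupported schema type: {schema_type}")


def check_defaults(config, properties):
    """Return a list of mismatches between config values and schema entries."""
    problems = []
    for name, spec in properties.items():
        if name not in config:
            problems.append(f"{name}: missing from config")
            continue
        value = config[name]
        if not matches_type(value, spec["type"]):
            problems.append(f"{name}: expected {spec['type']}, got {value!r}")
        elif value != spec.get("default"):
            problems.append(f"{name}: config {value!r} != schema default {spec.get('default')!r}")
    return problems


def missing_descriptions(properties):
    """List schema entries that have no non-empty description."""
    return [name for name, spec in properties.items() if not spec.get("description")]


problems = check_defaults(CONFIG_DEFAULTS, SCHEMA_PROPERTIES)
undocumented = missing_descriptions(SCHEMA_PROPERTIES)
print(problems)       # []
print(undocumented)   # ['num_trunk_recycles']
```

In this diff the defaults agree, and `num_trunk_recycles` is the only new entry without a `description`.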
10 changes: 7 additions & 3 deletions workflows/nf_chai/main.nf
@@ -16,8 +16,8 @@ include { CHAI_1 } from '../../modules/local/chai_1'
workflow NF_CHAI {

take:
fasta_file // string: path to fasta file read provided via --input parameter
weights_dir // string: path to model directory read provided via --weights_directory parameter
fasta_file // string: path to the FASTA file provided via the --input parameter
weights_dir // string: path to the model weights directory provided via the --weights_dir parameter

main:

@@ -34,7 +34,11 @@ workflow NF_CHAI {
// Run structure prediction with Chai-1
CHAI_1 (
ch_fasta,
weights_dir ? Channel.fromPath(weights_dir) : []
weights_dir ? Channel.fromPath(weights_dir) : [],
params.num_trunk_recycles,
params.num_diffn_timesteps,
params.seed,
params.use_esm_embeddings
)
ch_versions = ch_versions.mix(CHAI_1.out.versions)
