diff --git a/fusedrug/data/protein/antibody/antibody.py b/fusedrug/data/protein/antibody/antibody.py index 3736ea03..fd837d7d 100644 --- a/fusedrug/data/protein/antibody/antibody.py +++ b/fusedrug/data/protein/antibody/antibody.py @@ -1,4 +1,4 @@ -from typing import List, Dict +from typing import List, Dict, Optional from fusedrug.data.protein.structure.sabdab import load_sabdab_dataframe import pandas as pd from collections import namedtuple @@ -33,12 +33,16 @@ def get_antibody_regions(sequence: str, scheme: str = "chothia") -> Dict[str, st return ans -def get_antibodies_info_from_sabdab(antibodies_pdb_ids: List[str]) -> List[Antibody]: +def get_antibodies_info_from_sabdab( + antibodies_pdb_ids: Optional[List[str]] = None, +) -> List[Antibody]: """ Collects information on all provided antibodies_pdb_ids based on SabDab DB. """ sabdab_df = load_sabdab_dataframe() + if antibodies_pdb_ids is None: + antibodies_pdb_ids = sabdab_df.pdb.unique().tolist() antibodies = [] for pdb_id in antibodies_pdb_ids: found = sabdab_df[sabdab_df.pdb == pdb_id] diff --git a/fusedrug/data/protein/sequence/official_pdb_fasta.py b/fusedrug/data/protein/sequence/official_pdb_fasta.py new file mode 100644 index 00000000..bfdd7e5f --- /dev/null +++ b/fusedrug/data/protein/sequence/official_pdb_fasta.py @@ -0,0 +1,38 @@ +from io import StringIO +from Bio import SeqIO +from urllib.request import urlopen +from typing import Dict + + +def get_fasta_from_rcsb(pdb_id: str) -> Dict: # TODO: consider adding caching + """ + Given some pdb_id, (like "7vux"), we will retrieve its fasta file from rcsb database and return it as a dict {chain: sequence}. + """ + fasta_data = ( + urlopen(f"https://www.rcsb.org/fasta/entry/{pdb_id.upper()}") + .read() + .decode("utf-8") + ) + fasta_file_handle = StringIO(fasta_data) + chains_full_seq = SeqIO.to_dict( + SeqIO.parse(fasta_file_handle, "fasta"), + key_function=lambda rec: _description_to_author_chain_id(rec.description), + ) + chains_full_seq = {k: str(d.seq) for (k, d) in chains_full_seq.items()} + return chains_full_seq + + +def _description_to_author_chain_id(description: str) -> str: + loc = description.find(" ") + assert loc >= 0 + description = description[loc + 1 :] + loc = description.find(",") + if loc >= 0: + description = description[:loc] + + token = "auth " + loc = description.find(token) + if loc >= 0: + return description[loc + len(token)] + + return description[0] diff --git a/fusedrug/data/protein/structure/align_multiple_antibodies.py b/fusedrug/data/protein/structure/align_multiple_antibodies.py new file mode 100644 index 00000000..15a18db1 --- /dev/null +++ b/fusedrug/data/protein/structure/align_multiple_antibodies.py @@ -0,0 +1,104 @@ +from os.path import join, dirname +from fusedrug.data.protein.structure.flexible_align_chains_structure import ( + flexible_align_chains_structure, +) +from jsonargparse import CLI +import pandas as pd +from typing import Optional +import numpy as np + + +def main( + input_excel_filename: str, + unique_id_column: str, + reference_heavy_chain_pdb_filename_column: str, + reference_heavy_chain_id_column: str, + heavy_chain_pdb_filename_column: str, + heavy_chain_id_column: str, + light_chain_pdb_filename_column: str, + light_chain_id_column: str, + aligned_using_only_heavy_chain: bool = True, + output_structure_file_prefix: str = "aligned_antibody_", + output_excel_filename: Optional[str] = None, + output_excel_aligned_heavy_chain_pdb_filename_column: str = "aligned_heavy_chain_pdb_filename", + output_excel_aligned_heavy_chain_id_column: str = None, + output_excel_aligned_light_chain_pdb_filename_column: str = "aligned_light_chain_pdb_filename", + output_excel_aligned_light_chain_id_column: str = None, +) -> pd.DataFrame: + + assert ( + aligned_using_only_heavy_chain + ), "only supporting aligned_using_only_heavy_chain=True for now. Note that flexible_align_chains_structure is indeed flexible enough to support this, if needed." + + df = pd.read_excel(input_excel_filename, index_col=unique_id_column) + + # base = '/dccstor/dsa-ab-cli-val-0/2024_feb_delivery/top_100_with_indels/antibody_dimers_af2_predicted_structure' + # reference_heavy_chain = '/dccstor/dsa-ab-cli-val-0/targets/PD-1/7VUX/relaxed_complex/PD1_7VUX_H_eq.pdb' + + df[output_excel_aligned_heavy_chain_pdb_filename_column] = np.nan + df[output_excel_aligned_heavy_chain_id_column] = np.nan + df[output_excel_aligned_light_chain_pdb_filename_column] = np.nan + df[output_excel_aligned_light_chain_id_column] = np.nan + + for index, row in df.iterrows(): + reference_heavy_chain_pdb_filename = row[ + reference_heavy_chain_pdb_filename_column + ] + reference_heavy_chain_id = row[reference_heavy_chain_id_column] + # reference_light_chain_id = row[reference_light_chain_id_column] + + # heavy chain + heavy_chain_pdb_filename = row[heavy_chain_pdb_filename_column] + heavy_chain_id = row[heavy_chain_id_column] # 'A' + # light chain + light_chain_pdb_filename = row[light_chain_pdb_filename_column] + light_chain_id = row[light_chain_id_column] # 'B' + + output_aligned_fn = join( + dirname(heavy_chain_pdb_filename), output_structure_file_prefix + ) + + if not isinstance(reference_heavy_chain_pdb_filename, str): + print( + f"ERROR: expected reference_heavy_chain_pdb_filename to be string, but got {reference_heavy_chain_pdb_filename} of type {type(reference_heavy_chain_pdb_filename)}" + ) + continue + + if len(reference_heavy_chain_pdb_filename) < 2: + print( + f'ERROR: expected reference_heavy_chain_pdb_filename to be string, but got a suspicious empty or extremely short one: "{reference_heavy_chain_pdb_filename}"' + ) + continue + + flexible_align_chains_structure( + dynamic_ordered_chains=[(heavy_chain_pdb_filename, heavy_chain_id)], + apply_rigid_transformation_to_dynamic_chain_ids=[ + (heavy_chain_pdb_filename, heavy_chain_id), + (light_chain_pdb_filename, light_chain_id), + ], + static_ordered_chains=[ + (reference_heavy_chain_pdb_filename, reference_heavy_chain_id) + ], + output_pdb_filename_extentionless=output_aligned_fn, + ) + + # heavy chain + df.loc[index, output_excel_aligned_heavy_chain_pdb_filename_column] = ( + output_aligned_fn + f"_chain_{heavy_chain_id}.pdb" + ) + df.loc[index, output_excel_aligned_heavy_chain_id_column] = heavy_chain_id + # light chain + df.loc[index, output_excel_aligned_light_chain_pdb_filename_column] = ( + output_aligned_fn + f"_chain_{light_chain_id}.pdb" + ) + df.loc[index, output_excel_aligned_light_chain_id_column] = light_chain_id + + if output_excel_filename is not None: + df.to_excel(output_excel_filename) + print("saved ", output_excel_filename) + + return df + + +if __name__ == "__main__": + CLI(main) diff --git a/fusedrug/data/protein/structure/extract_chains_to_pdbs.py b/fusedrug/data/protein/structure/extract_chains_to_pdbs.py new file mode 100644 index 00000000..07bfaf5b --- /dev/null +++ b/fusedrug/data/protein/structure/extract_chains_to_pdbs.py @@ -0,0 +1,71 @@ +from jsonargparse import CLI +from fusedrug.data.protein.structure.structure_io import ( + load_pdb_chain_features, + save_structure_file, +) +from typing import Optional + + +def main( + *, + input_pdb_path: str, + orig_name_chains_to_extract: str, + output_pdb_path_extensionless: str, + output_chain_ids_to_extract: Optional[str] = None, +) -> None: + """ + + Takes an input PDB files and splits it into separate files, one per describe chain, allowing to rename the chains if desired + + Args: + input_pdb_path: + input_chain_ids_to_extract: '_' separated chain ids + output_chain_ids_to_extract: '_' separated chain ids + if not provided, will keep original chain ids + + """ + + orig_name_chains_to_extract = orig_name_chains_to_extract.split("_") + if output_chain_ids_to_extract is None: + output_chain_ids_to_extract = orig_name_chains_to_extract.split("_") + else: + output_chain_ids_to_extract = output_chain_ids_to_extract.split("_") + + assert len(orig_name_chains_to_extract) > 0 + assert len(orig_name_chains_to_extract) == len(output_chain_ids_to_extract) + assert len(orig_name_chains_to_extract[0]) == 1 + + loaded_chains = {} + for orig_chain_id in orig_name_chains_to_extract: + loaded_chains[orig_chain_id] = load_pdb_chain_features( + input_pdb_path, orig_chain_id + ) + + mapping = dict(zip(orig_name_chains_to_extract, output_chain_ids_to_extract)) + + loaded_chains_mapped = { + mapping[chain_id]: data for (chain_id, data) in loaded_chains.items() + } + + save_structure_file( + output_filename_extensionless=output_pdb_path_extensionless, + pdb_id="unknown", + chain_to_atom14={ + chain_id: data["atom14_gt_positions"] + for (chain_id, data) in loaded_chains_mapped.items() + }, + chain_to_aa_str_seq={ + chain_id: data["aasequence_str"] + for (chain_id, data) in loaded_chains_mapped.items() + }, + chain_to_aa_index_seq={ + chain_id: data["aatype"] + for (chain_id, data) in loaded_chains_mapped.items() + }, + save_cif=False, + mask=None, # TODO: check + ) + + +if __name__ == "__main__": + CLI(main) diff --git a/fusedrug/data/protein/structure/flexible_align_chains_structure.py b/fusedrug/data/protein/structure/flexible_align_chains_structure.py new file mode 100644 index 00000000..3abb62ce --- /dev/null +++ b/fusedrug/data/protein/structure/flexible_align_chains_structure.py @@ -0,0 +1,275 @@ +from jsonargparse import CLI +from typing import List, Union, Dict, Tuple, Optional +from Bio import Align +from tiny_openfold.utils.superimposition import superimpose + +# from fusedrug.data.protein.structure.protein_complex import ProteinComplex +from fusedrug.data.protein.structure.structure_io import ( + load_pdb_chain_features, + protein_utils, + # flexible_save_pdb_file, + save_structure_file, +) +import numpy as np +from warnings import warn + + +def flexible_align_chains_structure( + dynamic_ordered_chains: Union[List[Tuple], str], + apply_rigid_transformation_to_dynamic_chain_ids: Union[List[Tuple], str], + static_ordered_chains: Union[List[Tuple], str], + output_pdb_filename_extentionless: str, + minimal_matching_sequence_level_chunk: Optional[int] = 8, + backbone_only_based: bool = False, + ###chain_id_type:str = "author_assigned", +) -> None: + """ + Finds and applies a rigid transformation to align between chains (or sets of chains) + Searches first for sequence level alignment, and then uses the matching subset to find the rigid transformation + + IMPORTANT: if you provide multiple chains, the order matters and should be consistent with the order in static_ordered_chains + otherwise you might get nonsensical alignment ! + + Args: + + dynamic_ordered_chains: the chains from `pdb_dynamic` that we want to move. + either a list, for example: [ ('7vux', 'H'), ('/some/path/blah.pdb','N'), ...] #each tuple is [pdb id or filename, chain_id] + or a string, for example: "7vux^H@/some/path/blah.pdb^N #^ seprates between the different tuples and ^ separates between the tuple elements + IMPORTANT: if you provide multiple, the order matters and should be consistent with the order in static_ordered_chains otherwise you might get nonsensical alignment ! + + apply_rigid_transformation_to_dynamic_chain_ids: + either a list, for example: [ ('7vux', 'H'), ('/some/path/blah.pdb','N'), ...] #each tuple is [pdb id or filename, chain_id] + or a string, for example: "7vux^H@/some/path/blah.pdb^N #^ seprates between the different tuples and ^ separates between the tuple elements + + the found transformation will be applied to these changed, and these chains will be stored in the location that `output_pdb_filename` defines + It can be identical to dynamic_ordered_chains, or it can be different. + A use case in which making it different can make sense is to align heavy+light chains of a candidate antibody to the heavy chain of a reference + + + static_ordered_chains: the chains from `pdb_static` that we want to align the dynamic part to. + IMPORTANT: if you provide multiple, the order matters and should be consistent with the order in dynamic_ordered_chains otherwise you might get nonsensical alignment ! + either a list, for example: [ ('7vux', 'H'), ('/some/path/blah.pdb','N'), ...] #each tuple is [pdb id or filename, chain_id] + or a string, for example: "7vux^H@/some/path/blah.pdb^N #^ seprates between the different tuples and ^ separates between the tuple elements + + output_pdb_filename: the chains from pdb_dynamic that are selected and moved will be saved into this pdb file + + minimal_matching_sequence_level_chunk: the minimal size in which a chunk of matching aligned sequence will be used for the 3d alignment. + The motivation for this is to avoid "nonsense" matches scattered all over the sequence, resulting in (very) suboptimal alignment + + """ + + dynamic_ordered_chains = _to_list(dynamic_ordered_chains) + apply_rigid_transformation_to_dynamic_chain_ids = _to_list( + apply_rigid_transformation_to_dynamic_chain_ids + ) + static_ordered_chains = _to_list(static_ordered_chains) + + dynamic_chains: Dict[str, protein_utils.Protein] = {} + for pdb_file, chain_id in dynamic_ordered_chains: + dynamic_chains[chain_id] = load_pdb_chain_features(pdb_file, chain_id) + + apply_rigid_on_dynamic_chains: Dict[str, protein_utils.Protein] = {} + for pdb_file, chain_id in apply_rigid_transformation_to_dynamic_chain_ids: + apply_rigid_on_dynamic_chains[chain_id] = load_pdb_chain_features( + pdb_file, chain_id + ) + + static_chains: Dict[str, protein_utils.Protein] = {} + for pdb_file, chain_id in static_ordered_chains: + static_chains[chain_id] = load_pdb_chain_features(pdb_file, chain_id) + + attributes = [ + "atom14_gt_positions", + "atom14_gt_exists", + "aasequence_str", + "aatype", + # "residue_index", + ] + + # concatanate + dynamic_concat = { + attribute: _concat_elements_from_dict(dynamic_chains, attribute) + for attribute in attributes + } + + static_concat = { + attribute: _concat_elements_from_dict(static_chains, attribute) + for attribute in attributes + } + + # calculate alignment in sequence space + dynamic_indices, static_indices = get_alignment_indices( + dynamic_concat["aasequence_str"], + static_concat["aasequence_str"], + minimal_matching_sequence_level_chunk=minimal_matching_sequence_level_chunk, + ) + + # dynamic_indices = dynamic_indices[:50] + # static_indices = static_indices[:50] + + # extract seq-level matching atoms coordinates + dynamic_matching = _apply_indices(dynamic_concat, dynamic_indices) + static_matching = _apply_indices(static_concat, static_indices) + + # calculate the rigid transformation to translate from the starting pose of the dynamic onto the static + + combined_mask = np.logical_and( + dynamic_matching["atom14_gt_exists"].astype(bool), + static_matching["atom14_gt_exists"].astype(bool), + ) + # orig_atom_pos_shape = dynamic_matching["atom14_gt_positions"].shape + use_for_static = static_matching["atom14_gt_positions"] + use_for_dynamic = dynamic_matching["atom14_gt_positions"] + if backbone_only_based: + use_for_static = use_for_static[:, :4, ...] + use_for_dynamic = use_for_dynamic[:, :4, ...] + + _, rmsd, rot_matrix, trans_matrix = superimpose( + use_for_static.reshape(-1, 3), + use_for_dynamic.reshape(-1, 3), + combined_mask.reshape(-1), + verbose=True, + ) + + assert rot_matrix.shape == (1, 3, 3) + rot_matrix = rot_matrix[0] + + assert trans_matrix.shape == (1, 3) + trans_matrix = trans_matrix[0] + + assert len(rmsd.shape) == 0 + + if rmsd > 6.0: + warn( + f"flexible_align_chains_structure: got a pretty high rmsd={rmsd} in alignment. Either the structures are very different or the sequence alignment was suboptimal." + ) + + # apply the rigid transformation on the chains described in `apply_rigid_transformation_to_dynamic_chain_ids` argument + transformed_dynamic_atom_pos = {} + for chain_id, prot in apply_rigid_on_dynamic_chains.items(): + _atom_pos_orig_shape = prot["atom14_gt_positions"].shape + _atom_pos_flat = prot["atom14_gt_positions"].reshape(-1, 3) + _atom_pos_flat_transformed = np.dot(_atom_pos_flat, rot_matrix) + trans_matrix + _atom_pos_transformed = _atom_pos_flat_transformed.reshape( + *_atom_pos_orig_shape + ) + transformed_dynamic_atom_pos[chain_id] = _atom_pos_transformed + + # transformed_dynamic_atom_pos[chain_id] = prot['atom14_gt_positions'] + + save_structure_file( + output_filename_extensionless=output_pdb_filename_extentionless, + pdb_id="unknown", + chain_to_atom14=transformed_dynamic_atom_pos, + chain_to_aa_str_seq={ + chain_id: apply_rigid_on_dynamic_chains[chain_id]["aasequence_str"] + for chain_id in apply_rigid_on_dynamic_chains.keys() + }, + chain_to_aa_index_seq={ + chain_id: apply_rigid_on_dynamic_chains[chain_id]["aatype"] + for chain_id in apply_rigid_on_dynamic_chains.keys() + }, + save_cif=False, + mask=None, # TODO: check + ) + + +def _apply_indices(x: Dict, indices: np.ndarray) -> Tuple[str, np.ndarray]: + ans = {} + for k, d in x.items(): + if isinstance(d, str): + ans[k] = "".join(d[i] for i in indices) + else: + ans[k] = d[indices] + return ans + + +def get_alignment_indices( + target: str, + query: str, + minimal_matching_sequence_level_chunk: Optional[int] = None, +) -> Tuple[np.ndarray, np.ndarray]: + aligner = Align.PairwiseAligner() + + ###https://biopython.org/docs/1.75/api/Bio.Align.html#Bio.Align.PairwiseAlignment + ### https://github.com/biopython/biopython/blob/master/Bio/Align/substitution_matrices/data/README.txt + aligner.substitution_matrix = Align.substitution_matrices.load("BLOSUM62") + + alignments = aligner.align(target, query) + alignment = alignments[0] + + target_indices = [] + query_indices = [] + + for (target_start, target_end), (query_start, query_end) in zip(*alignment.aligned): + if (minimal_matching_sequence_level_chunk is None) or ( + target_end - target_start >= minimal_matching_sequence_level_chunk + ): + target_indices.extend(list(range(target_start, target_end))) + query_indices.extend(list(range(query_start, query_end))) + + if len(target_indices) == 0: + raise Exception( + f"ERROR: in flexible_align_chains_structure(), could not align even a single chunk of minimal defined size {minimal_matching_sequence_level_chunk}" + ) + + target_indices = np.array(target_indices) + query_indices = np.array(query_indices) + + return target_indices, query_indices + + +def _concat_elements_from_dict( + input_dict: Dict, attribute: str +) -> Union[str, np.ndarray]: + # elements = [getattr(p, attribute) for (_, p) in input_dict.items()] + elements = [p[attribute] for (_, p) in input_dict.items()] + ans = _concat_elements(elements) + return ans + + +def _concat_elements(elements: List[Union[str, np.ndarray]]) -> Union[str, np.ndarray]: + assert len(elements) > 0 + if isinstance(elements[0], str): + return "".join(elements) + + ans = np.concatenate(elements, axis=0) + return ans + + +def _to_list(x: Union[str, List]) -> List: + if isinstance(x, str): + x = x.split("@") + x = [tuple(curr.split("^")) for curr in x] + assert isinstance(x, list) + + for element in x: + assert len(element) == 2 + return x + + +if __name__ == "__main__": + CLI(flexible_align_chains_structure) + + +####usage examples + +""" +python $MY_GIT_REPOS/fuse-drug/fusedrug/data/protein/structure/flexible_align_chains_structure.py \ + $MY_GIT_REPOS/fuse-drug/fusedrug/tests_data/structure/protein/flexible_align/antibody_dimer_candidate_with_indels_NOT_aligned.pdb^A \ + $MY_GIT_REPOS/fuse-drug/fusedrug/tests_data/structure/protein/flexible_align/antibody_dimer_candidate_with_indels_NOT_aligned.pdb^A@$MY_GIT_REPOS/fuse-drug/fusedrug/tests_data/structure/protein/flexible_align/antibody_dimer_candidate_with_indels_NOT_aligned.pdb^B \ + $MY_GIT_REPOS/fuse-drug/fusedrug/tests_data/structure/protein/flexible_align/PD1_7VUX_antibody_heavy_chain_from_equalized_reference_complex.pdb^H \ + $MY_GIT_REPOS/fuse-drug/fusedrug/tests_data/structure/protein/flexible_align/output_aligned_candidate_antibody_dimer_only_H_for_alignment + + + +python $MY_GIT_REPOS/fuse-drug/fusedrug/data/protein/structure/flexible_align_chains_structure.py \ + $MY_GIT_REPOS/fuse-drug/fusedrug/tests_data/structure/protein/flexible_align/antibody_dimer_candidate_with_indels_NOT_aligned.pdb^A@$MY_GIT_REPOS/fuse-drug/fusedrug/tests_data/structure/protein/flexible_align/antibody_dimer_candidate_with_indels_NOT_aligned.pdb^B \ + $MY_GIT_REPOS/fuse-drug/fusedrug/tests_data/structure/protein/flexible_align/antibody_dimer_candidate_with_indels_NOT_aligned.pdb^A@$MY_GIT_REPOS/fuse-drug/fusedrug/tests_data/structure/protein/flexible_align/antibody_dimer_candidate_with_indels_NOT_aligned.pdb^B \ + $MY_GIT_REPOS/fuse-drug/fusedrug/tests_data/structure/protein/flexible_align/PD1_7VUX_antibody_heavy_chain_from_equalized_reference_complex.pdb^H@$MY_GIT_REPOS/fuse-drug/fusedrug/tests_data/structure/protein/flexible_align/PD1_7VUX_antibody_light_chain_from_equalized_reference_complex.pdb^L \ + $MY_GIT_REPOS/fuse-drug/fusedrug/tests_data/structure/protein/flexible_align/output_aligned_candidate_antibody_dimer_used_both_LH_for_alignment + + + + +""" diff --git a/fusedrug/data/protein/structure/protein_complex.py b/fusedrug/data/protein/structure/protein_complex.py index 48d02d8a..99288eb1 100644 --- a/fusedrug/data/protein/structure/protein_complex.py +++ b/fusedrug/data/protein/structure/protein_complex.py @@ -22,18 +22,26 @@ def __init__(self, verbose: bool = True) -> None: self.chains_data = {} # maps from chain description (e.g. ('7vux', 'A')) to self.flattened_data = {} + # a key is a tuple in the format (pdb_id, chain_id) + self.per_chain_most_frequent_residue_part = {} + self.per_chain_mmcif_object = {} + self.per_chain_mmcif_dict = {} + def add( self, - pdb_id: str, + pdb_id_or_filename: str, + pdb_id: Optional[str] = None, chain_ids: Optional[List[Union[str, int]]] = None, load_protein_structure_features_overrides: Dict = None, - min_chain_residues_count: int = 10, + min_chain_residues_count: int = 8, max_residue_type_part: float = 0.5, allow_dna_or_rna_in_complex: bool = False, - ) -> None: + chain_id_type: str = "author_assigned", + ) -> bool: """ Args: - pdb_id: for example '7vux' + pdb_id_or_filename: for example '7vux' or '/some/path/to/7vux.cif.gz' + pdb_id: must be provided if pdb_id_or_filename is a filename chain_ids: provide None (default) to load all chains provide a list of chain identifiers to select which are loaded. use str to use chain_id @@ -51,19 +59,19 @@ def add( load_protein_structure_features_overrides = {} ans = load_protein_structure_features( - pdb_id, - pdb_id=pdb_id if len(pdb_id) == 4 else None, + pdb_id_or_filename=pdb_id_or_filename, + pdb_id=pdb_id, chain_id=chain_ids, - also_return_mmcif_object=True, + chain_id_type=chain_id_type, **load_protein_structure_features_overrides, ) if ans is None: if self.verbose: print(f"ProteinComplex::add could not load pdb_id={pdb_id}") - return + return False - loaded_chains, mmcif_object = ans + loaded_chains, mmcif_object, mmcif_dict = ans if not allow_dna_or_rna_in_complex: if mmcif_object.info["rna_or_dna_only_sequences_count"] > 0: @@ -71,11 +79,13 @@ def add( print( f'dna or rna sequences are not allowed, and detected {mmcif_object.info["rna_or_dna_only_sequences_count"]}' ) - return + return False # min_chain_residues_count:int = 10, # max_residue_type_part:float = 0.5, + added_any = False + for k, d in loaded_chains.items(): if min_chain_residues_count is not None: if len(d["aa_sequence_str"]) < min_chain_residues_count: @@ -84,10 +94,14 @@ def add( f"chain {k} is too small, less than {min_chain_residues_count}" ) continue + most_frequent_residue_part = d["aatype"].unique(return_counts=True)[ + 1 + ].max() / len(d["aatype"]) + self.per_chain_most_frequent_residue_part[ + (pdb_id, k) + ] = most_frequent_residue_part.item() if max_residue_type_part is not None: - most_frequent_residue_part = d["aatype"].unique(return_counts=True)[ - 1 - ].max() / len(d["aatype"]) + if most_frequent_residue_part > max_residue_type_part: if self.verbose: print( @@ -96,6 +110,11 @@ def add( continue self.chains_data[(pdb_id, k)] = d + self.per_chain_mmcif_object[(pdb_id, k)] = mmcif_object + self.per_chain_mmcif_dict[(pdb_id, k)] = mmcif_dict + added_any = True + + return added_any def flatten( self, @@ -177,7 +196,7 @@ def spatial_crop( ) -> None: """ Spatial crop of a pair of chains which favors interacting residues. - Note - you must call "flatten" (with only two chain descriptor) prior to calling this method. + Note - you must call "flatten" prior to calling this method. The code is heavily influenced from the spatial crop done in RF2 """ @@ -540,6 +559,59 @@ def calculate_chains_interaction_info( non_interacting_pairs=non_interacting_pairs, ) + def get_main_features(self, atom_representation: int = 14) -> Dict: + """ + Returns a dictionary, extracting from all features based on the requested atom_representation + { + "chain_identifier" : [ list of chain identifiers, each is (pdb_id, chain_id)], + "atom_positions" : [list of numpy arrays each contains atom position and having the shape [residues num, 14 or 37, 3]], + "aa_types" : [list of numpy arrays each containing integer values of amino-acid types, based on tiny_openfold.np.residue_constants.restypes_with_x order], + "gt_atom_exists": [list of numpy arrays each containing a boolean value if that atom exists, in the shape [residues num, 14 or 37]], + TODO: add the str version + } + """ + if atom_representation == 14: + features_names = dict( + atom_positions="atom14_gt_positions", + atom_exists="atom14_gt_exists", + aatype="aatype", + bfactors="atom14_bfactors", # + ) + elif self.atom_representation == 37: + features_names = dict( + atom_positions="all_atom_positions", + atom_exists="all_atom_mask", # it seems that atom37_atom_exists contains valid masks for residues with missing positional data + aatype="aatype", + bfactors="all_atom_bfactors", + ) + else: + raise Exception( + f"Only supported options for atom_representation are 14 and 37 of type integer. Got {atom_representation} of type {type(atom_representation)}" + ) + + ans = {} + + ans["chain_identifier"] = [] + ans["atom_positions"] = [] + ans["aa_types"] = [] + ans["gt_atom_exists"] = [] + ans["aa_sequence_str"] = [] + ans["bfactors"] = [] + ###ans['chain_index'] = [] + + for pdb_id_chain_id, (k, chain_data) in enumerate(self.chains_data.items()): + ans["chain_identifier"].append(k) + ans["atom_positions"].append(chain_data[features_names["atom_positions"]]) + ans["aa_types"].append(chain_data[features_names["aatype"]]) + ans["gt_atom_exists"].append(chain_data[features_names["atom_exists"]]) + ans["bfactors"].append(chain_data[features_names["bfactors"]]) + ans["aa_sequence_str"].append(chain_data["aa_sequence_str"]) + ###ans['chain_index'] += [chain_index]*ans['atom_positions'].shape[0] + + ans["atom_representation"] = atom_representation + + return ans + def calculate_number_of_interacting_residues( *, diff --git a/fusedrug/data/protein/structure/sabdab.py b/fusedrug/data/protein/structure/sabdab.py index 9863fd90..f2752538 100644 --- a/fusedrug/data/protein/structure/sabdab.py +++ b/fusedrug/data/protein/structure/sabdab.py @@ -20,3 +20,40 @@ def load_sabdab_dataframe(path: Optional[str] = None) -> pd.DataFrame: path = os.path.join(os.environ["SABDAB_DIR"], "sabdab_summary_all.tsv") df = pd.read_csv(path, sep="\t") return df + + +class SAbDAb: + """ + A very simplistic class for loading sabdab entries. useful for quick testing and debugging. + For more complex processing and usage please use bmfm-bench + """ + + def __init__(self, main_dataframe_path: str = None): + self.df = load_sabdab_dataframe(main_dataframe_path) + + def get_entry(self, pdb_id: str, heavy_chain_id: Optional[str] = None) -> pd.Series: + if heavy_chain_id is not None: + found = self.df.loc[ + (self.df.pdb == pdb_id) & (self.df.Hchain == heavy_chain_id) + ] + else: + found = self.df.loc[self.df.pdb == pdb_id] + + if found.shape[0] == 0: + raise Exception( + f"could not find an entry for pdb_id={pdb_id} heavy_chain_id={heavy_chain_id}" + ) + elif found.shape[0] > 1: + raise Exception( + f"found multiple entries for pdb_id={pdb_id} heavy_chain_id={heavy_chain_id}" + ) + + found = found.iloc[0] + + return found + + +if __name__ == "__main__": + inst = SAbDAb() + inst.get_entry(pdb_id="7vux", heavy_chain_id="H") + inst.get_entry(pdb_id="7vux") diff --git a/fusedrug/data/protein/structure/structure_io.py b/fusedrug/data/protein/structure/structure_io.py index ea94eeb0..78e68090 100644 --- a/fusedrug/data/protein/structure/structure_io.py +++ b/fusedrug/data/protein/structure/structure_io.py @@ -5,7 +5,7 @@ import torch from copy import deepcopy import pathlib - +from tqdm import trange import numpy as np from Bio.PDB import * # noqa: F401, F403 from Bio.PDB import StructureBuilder @@ -18,7 +18,7 @@ from Bio.PDB.Residue import Residue from Bio.PDB.Atom import Atom from Bio import PDB -from warnings import warn + from tiny_openfold.data import data_transforms from tiny_openfold.utils.tensor_utils import tree_map @@ -56,7 +56,7 @@ def save_structure_file( save_cif: bool = True, b_factors: Optional[Dict[str, torch.Tensor]] = None, reference_cif_filename: Optional[str] = None, - mask: Optional[List] = None, + mask: Optional[Dict[str, List]] = None, ) -> List[str]: """ A helper function allowing to save single or multi chain structure into pdb and/or mmcif format. @@ -72,7 +72,7 @@ def save_structure_file( save_cif - should it store mmCIF format (newer, and no length limits) b_factors - reference_cif_filename:Optional[str] - for mmCIF outputs you must provide an mmCIF reference file (you can use the ground truth one) - mask:Optional[List] - a mask describing which residues to store + mask: - an optional dictionary mapping chain_id to *residue-level* mask Returns: A list with paths for all saved files @@ -118,6 +118,10 @@ def save_structure_file( for chain_id in sorted_chain_ids: pos_atom14 = chain_to_atom14[chain_id] + if mask is not None: + curr_mask = mask[chain_id] + else: + curr_mask = torch.full((pos_atom14.shape[0],), fill_value=True) if save_pdb: out_pdb = output_filename_extensionless + "_chain_" + chain_id + ".pdb" @@ -136,9 +140,7 @@ def save_structure_file( if b_factors is not None else torch.tensor([100.0] * pos_atom14.shape[0]), sequence=chain_to_aa_index_seq[chain_id], - residues_mask=mask - if mask is not None - else torch.full((pos_atom14.shape[0],), fill_value=True), + residues_mask=curr_mask, save_path=out_pdb, init_chain=potentially_fixed_chain_id, model=0, @@ -174,7 +176,6 @@ def load_protein_structure_features( chain_id_type: str = "author_assigned", device: str = "cpu", max_allowed_file_size_mbs: float = None, - also_return_mmcif_object: bool = False, ) -> Union[Tuple[str, dict], None]: """ Extracts ground truth features from a given pdb_id or filename. @@ -249,8 +250,11 @@ def load_protein_structure_features( return None elif structure_file_format == "cif": try: - mmcif_object = parse_mmcif( - native_structure_filename, unique_file_id=pdb_id, quiet_parsing=True + mmcif_object, mmcif_dict = parse_mmcif( + native_structure_filename, + unique_file_id=pdb_id, + quiet_parsing=True, + also_return_mmcif_dict=True, ) chains_names = list(mmcif_object.chain_to_seqres.keys()) except Exception as e: @@ -308,9 +312,11 @@ def load_protein_structure_features( "aatype", "all_atom_positions", "all_atom_mask", + "all_atom_bfactors", "resolution", "residue_index", "chain_index", + "all_atom_bfactors", ] } @@ -336,8 +342,7 @@ def load_protein_structure_features( else: final_ans = ans[chain_id] - if also_return_mmcif_object: - final_ans = (final_ans, mmcif_object) + final_ans = (final_ans, mmcif_object, mmcif_dict) return final_ans @@ -360,6 +365,7 @@ def calculate_additional_features(gt_mmcif_feats: Dict) -> Dict: gt_mmcif_feats = data_transforms.make_atom14_masks(gt_mmcif_feats) gt_mmcif_feats = data_transforms.make_atom14_positions(gt_mmcif_feats) + gt_mmcif_feats = data_transforms.make_atom14_bfactors(gt_mmcif_feats) # for reference, remember .../openfold/openfold/data/input_pipeline.py # data_transforms.make_atom14_masks @@ -427,8 +433,27 @@ def structure_from_pdb(pdb_filename: str) -> Structure: return structure +def load_pdb_chain_features( + filename: str, + chain_id: Optional[str] = None, + also_return_openfold_protein: bool = False, +) -> Dict: + prot = pdb_to_openfold_protein( + filename, + chain_id, + ) + + features = convert_openfold_protein_to_dict(prot) + features = calculate_additional_features(features) + + if also_return_openfold_protein: + return features, prot + return features + + def pdb_to_openfold_protein( - filename: str, chain_id: Optional[str] = None + filename: str, + chain_id: Optional[str] = None, ) -> protein_utils.Protein: """ Loads data from the pdb file - which includes the atoms positions, atom mask, the AA sequence. @@ -445,6 +470,35 @@ def pdb_to_openfold_protein( # return protein_utils.from_pdb_string(f.read(), chain_id=chain_id) +def convert_openfold_protein_to_dict( + prot: protein_utils.Protein, to_torch: bool = True +) -> Dict: + """ + Note: Aligning with the mmcif code expected names + """ + + names_mapping = { # Protin to expected keys in dict + "atom_positions": "all_atom_positions", + "aatype": "aatype", + "atom_mask": "all_atom_mask", + #'residue_index' : 'residue_index', + #'b_factors' : , + #'chain_index' : , + #'remark' : , + #'parents' : , + #'parents_chain_index' : , + "aasequence_str": "aasequence_str", + } + + ans = {} + for from_name, to_name in names_mapping.items(): + ans[to_name] = getattr(prot, from_name) + if to_torch and (not isinstance(ans[to_name], str)): + ans[to_name] = torch.from_numpy(ans[to_name]) + + return ans + + def get_available_chain_ids_in_pdb(filename: str) -> List[str]: """ Will return all available chain ids in a pdb file, performs some filtering to get protein chains and not other chains @@ -566,6 +620,7 @@ def save_trajectory_to_pdb_file( save_path: str, traj_b_factors: torch.Tensor = None, init_chain: str = "A", + verbose: bool = False, ) -> None: """ Stores a trajectory into a single PDB file. @@ -590,7 +645,9 @@ def save_trajectory_to_pdb_file( builder = StructureBuilder.StructureBuilder() builder.init_structure(0) - for model in range(traj_xyz.shape[0]): + use_range_func = trange if verbose else range + + for model in use_range_func(traj_xyz.shape[0]): builder.init_model(model) builder.init_chain(init_chain) builder.init_seg(" ") @@ -599,13 +656,19 @@ def save_trajectory_to_pdb_file( xyz = traj_xyz[model] b_factors = traj_b_factors[model] + if torch.is_tensor(residues_mask): + residues_mask = residues_mask.bool() + else: + residues_mask = residues_mask.astype(bool) + for i, (aa_idx, p_res, b, m_res) in enumerate( - zip(sequence, xyz, b_factors, residues_mask.bool()) + zip(sequence, xyz, b_factors, residues_mask) ): if not m_res: continue aa_idx = aa_idx.item() - p_res = p_res.clone().detach().cpu() # fixme: this looks slow + if torch.is_tensor(p_res): + p_res = p_res.clone().detach().cpu() # fixme: this looks slow if aa_idx == 21: continue try: @@ -694,41 +757,59 @@ def flexible_save_pdb_file( ) xyz = xyz[:, :4, ...] - elif xyz.shape[1] != 14: - warn( - f"flexible_save_pdb_file:: info: note that xyz contains {xyz.shape[1]} max atoms, and not max 14 atoms (all possible heavy atoms). This is ok if intentional, for example, when outputting only backbone." - ) + assert xyz.shape[1] in [ + 4, + 14, + 37, + ], f"xyz shape is allowed to be 14 (all heavy atoms) or 4 (only BB), got xyz.shap={xyz.shape}" if b_factors is None: - b_factors = torch.tensor([100.0] * xyz.shape[0]) + # b_factors = torch.tensor([100.0] * xyz.shape[0]) + b_factors = torch.zeros((xyz.shape[:-1])) builder = StructureBuilder.StructureBuilder() builder.init_structure(0) builder.init_model(model) builder.init_chain(init_chain) builder.init_seg(" ") + if torch.is_tensor(residues_mask): + residues_mask = residues_mask.bool() + + if torch.is_tensor(xyz): + xyz = xyz.clone().detach().cpu() + for i, (aa_idx, p_res, b, m_res) in enumerate( - zip(sequence, xyz, b_factors, residues_mask.bool()) + zip(sequence, xyz, b_factors, residues_mask) ): if not m_res: continue aa_idx = aa_idx.item() - p_res = p_res.clone().detach().cpu() # fixme: this looks slow - if aa_idx == 21: + + if aa_idx == 21: # is this X ? (unknown/special) continue try: three = residx_to_3(aa_idx) except IndexError: continue builder.init_residue(three, " ", int(i), icode=" ") - for j, (atom_name,) in enumerate( - zip(rc.restype_name_to_atom14_names[three]) - ): # why is zip used here? - if (len(atom_name) > 0) and (len(p_res) > j): + + if xyz.shape[1] == 37: + atom_names = rc.atom_types + else: + atom_names = rc.restype_name_to_atom14_names[three] + + residue_atom_names = rc.residue_atoms[three] + + for j, (atom_name,) in enumerate(zip(atom_names)): # why is zip used here? + if ( + (len(atom_name) > 0) + and (len(p_res) > j) + and atom_name in residue_atom_names + ): builder.init_atom( atom_name, p_res[j].tolist(), - b.item(), + b[j].item(), 1.0, " ", atom_name.join([" ", " "]), @@ -739,6 +820,7 @@ def flexible_save_pdb_file( io.set_structure(structure) os.makedirs(pathlib.Path(save_path).parent, exist_ok=True) io.save(save_path) + pass def save_pdb_file( diff --git a/fusedrug/eval/metrics/protein_sequences.py b/fusedrug/eval/metrics/protein_sequences.py index 7257b71b..1891a6e9 100644 --- a/fusedrug/eval/metrics/protein_sequences.py +++ b/fusedrug/eval/metrics/protein_sequences.py @@ -18,7 +18,7 @@ """ from typing import List, Dict, Any from functools import partial - +import pandas as pd from fuse.eval.metrics.metrics_common import MetricPerBatchDefault from Bio import Align import difflib @@ -89,8 +89,8 @@ def _pairwise_protein_sequence_alignment_compute( def _pairwise_aligned_score(preds: List[str], target: List[str]) -> List[float]: - assert isinstance(preds, list) - assert isinstance(target, list) + assert isinstance(preds, (list, pd.Series)) + assert isinstance(target, (list, pd.Series)) assert len(preds) == len(target) penalty_score = 0.0