Skip to content

Commit

Permalink
Merge branch 'main' of https://github.com/BiomedSciAI/fuse-drug into …
Browse files Browse the repository at this point in the history
…yoels
  • Loading branch information
YoelShoshan committed May 16, 2024
2 parents 72b2328 + c788d1a commit a0a6c02
Show file tree
Hide file tree
Showing 23 changed files with 642 additions and 30 deletions.
3 changes: 0 additions & 3 deletions .gitmodules
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
[submodule "Contrastive_PLM_DTI"]
path = fusedrug_examples/interaction/drug_target/affinity_prediction/PLM_DTI/Contrastive_PLM_DTI
url = https://github.com/alex-golts/Contrastive_PLM_DTI.git # The original repo is not stable
[submodule "fusedrug_examples/interaction/drug_target/affinity_prediction/PLM_DTI/Contrastive_PLM_DTI"]
path = fusedrug_examples/interaction/drug_target/affinity_prediction/PLM_DTI/Contrastive_PLM_DTI
url = https://github.com/alex-golts/Contrastive_PLM_DTI.git # The original repo is not stable
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def __init__(

self._ligands_smi = ligands_smi
# _indexed_table_table_kwargs = dict(
# #seperator='\t',
# #separator='\t',
# #id_column_idx=1,
# allow_access_by_id=True
# )
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional
from typing import Optional, Dict
from fuse.utils import NDict
from fuse.data import OpBase, get_sample_id
from fusedrug.utils.file_formats import IndexedTextTable
Expand All @@ -15,6 +15,7 @@ def __init__(
table_file_loc: Optional[str] = None,
index_filename: Optional[str] = None,
id_column_index: int = 0,
rename_columns: Optional[Dict[str, str]] = None,
separator: str = " ",
allow_access_by_id: bool = False, # best leave it at False for large files
**kwargs: dict,
Expand All @@ -24,31 +25,38 @@ def __init__(
the file format is expected to be a text file in which each line is expected to be ' ' separated,
containing the columns named
:param index_filename: index file for the table, if not exist or None, it will recreate the index
:param rename_columns: rename columns from table, when None (default) column names are kept
"""
super().__init__(**kwargs)
self._table_file_loc = table_file_loc
self._id_column_index = id_column_index
self._rename_columns = rename_columns if rename_columns is not None else {}
self._separator = separator
self._allow_access_by_id = allow_access_by_id
self._indexed_text_table = IndexedTextTable(
filename=table_file_loc,
index_filename=index_filename,
seperator=self._separator,
separator=self._separator,
id_column_idx=self._id_column_index,
allow_access_by_id=self._allow_access_by_id,
)

def __call__(
self,
sample_dict: NDict,
key_out_seq: str = "data",
self, sample_dict: NDict, key_out_prefix: Optional[str] = None
) -> NDict:
sid = get_sample_id(sample_dict)
assert isinstance(sid, (int, numpy.int64))
assert isinstance(
sid, (int, numpy.int64, numpy.int32, numpy.uint32, numpy.uint64)
)

_, entry_data = self._indexed_text_table[sid]

for c in entry_data.axes[0]:
sample_dict[f"{key_out_seq}.{c}"] = entry_data[c]
if key_out_prefix is None:
sample_dict[self._rename_columns.get(c, c)] = entry_data[c]
else:
sample_dict[
f"{key_out_prefix}.{self._rename_columns.get(c,c)}"
] = entry_data[c]

return sample_dict
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def __call__(

ligands_table = IndexedTextTable(
smiles_path,
seperator="\t",
separator="\t",
first_row_is_columns_names=False,
columns_names=["molecule_sequence", "molecule_id"],
id_column_name="molecule_id",
Expand All @@ -99,7 +99,7 @@ def __call__(

proteins_table = IndexedTextTable(
proteins_path,
seperator="\t",
separator="\t",
first_row_is_columns_names=False,
columns_names=["protein_sequence", "protein_id"],
id_column_name="protein_id",
Expand Down
4 changes: 2 additions & 2 deletions fusedrug/data/molecule/ops/loaders/smi_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def __init__(
self,
smi_file_loc: Optional[str] = None,
molecule_id_column_idx: int = 0,
seperator: str = "\t",
separator: str = "\t",
allow_access_by_id: bool = True,
**kwargs: dict
):
Expand All @@ -25,7 +25,7 @@ def __init__(
super().__init__(**kwargs)
self._smi_file_loc = smi_file_loc
self._molecule_id_column_idx = molecule_id_column_idx
self._seperator = seperator
self._separator = separator
self._allow_access_by_id = allow_access_by_id
self._indexed_text_table = IndexedTextTable(
smi_file_loc,
Expand Down
4 changes: 2 additions & 2 deletions fusedrug/data/protein/ops/loaders/tests/test_aa_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from fusedrug.data.protein.ops.aa_ops import (
OpToUpperCase,
OpKeepOnlyUpperCase,
) # OpAddSeperator, OpStrToTokenIds, OpTokenIdsToStr, OpMaskRandom, OpCropRandom
) # OpAddSeparator, OpStrToTokenIds, OpTokenIdsToStr, OpMaskRandom, OpCropRandom

from fuse.data import PipelineDefault
import os
Expand Down Expand Up @@ -44,7 +44,7 @@ def test_aa_ops(self) -> None:
{},
),
# (OpRepeat(OpCropRandom, [dict(key_out='data.gt.seq'), dict(key_out='data.input.seq')]), {} ),
# (OpRepeat(OpAddSeperator, [
# (OpRepeat(OpAddSeparator, [
# dict(inputs={'data.gt.seq':'seq'}, outputs='data.gt.seq'),
# dict(inputs={'data.input.seq':'seq'}, outputs='data.input.seq'),
# ]), {}),
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import click
from typing import Union
from fusedrug.data.tokenizer.modulartokenizer.modular_tokenizer import ModularTokenizer
from fusedrug.data.tokenizer.modulartokenizer.create_multi_tokenizer import (
test_tokenizer,
Expand All @@ -22,7 +23,7 @@
help="path to write tokenizer in",
)
# # this needs to be run on all the related modular tokenizers
def main(tokenizer_path: str, output_path: str | None) -> None:
def main(tokenizer_path: str, output_path: Union[str, None]) -> None:
print(f"adding special tokens to {tokenizer_path}")
if output_path is None:
output_path = tokenizer_path
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -743,7 +743,7 @@ def set_field(tokenizers_info_cfg: List, name: str, key: str, val: Any) -> List:
tokenizers_info_cfg = self.tokenizers_info_raw_cfg

if not os.path.exists(path):
os.makedirs(path)
os.makedirs(path, exist_ok=True)
for t_type in self.tokenizers_info:
tokenizer_inst = self.tokenizers_info[t_type]["tokenizer_inst"]
if self.tokenizers_info[t_type]["json_path"] is not None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2648,6 +2648,60 @@
"rstrip": false,
"normalized": false,
"special": true
},
{
"id": 294,
"content": "<COMPLEX_ENTITY>",
"single_word": false,
"lstrip": false,
"rstrip": false,
"normalized": false,
"special": true
},
{
"id": 295,
"content": "<ALTERNATIVE>",
"single_word": false,
"lstrip": false,
"rstrip": false,
"normalized": false,
"special": true
},
{
"id": 296,
"content": "<CDR3_REGION>",
"single_word": false,
"lstrip": false,
"rstrip": false,
"normalized": false,
"special": true
},
{
"id": 297,
"content": "<GENERAL_CHAIN>",
"single_word": false,
"lstrip": false,
"rstrip": false,
"normalized": false,
"special": true
},
{
"id": 298,
"content": "<SUBMOLECULAR_ENTITY>",
"single_word": false,
"lstrip": false,
"rstrip": false,
"normalized": false,
"special": true
},
{
"id": 299,
"content": "<MUTATED>",
"single_word": false,
"lstrip": false,
"rstrip": false,
"normalized": false,
"special": true
}
],
"normalizer": null,
Expand Down Expand Up @@ -2957,6 +3011,12 @@
"<CORRUPTED_AREA_END>": 291,
"<MOLECULAR_ENTITY_MUTATED_PROTEIN_CHAIN>": 292,
"<MOLECULAR_ENTITY_PROTEIN_CHAIN>": 293,
"<COMPLEX_ENTITY>": 294,
"<ALTERNATIVE>": 295,
"<CDR3_REGION>": 296,
"<GENERAL_CHAIN>": 297,
"<SUBMOLECULAR_ENTITY>": 298,
"<MUTATED>": 299,
"#": 527,
"%": 528,
"(": 529,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2648,6 +2648,60 @@
"rstrip": false,
"normalized": false,
"special": true
},
{
"id": 294,
"content": "<COMPLEX_ENTITY>",
"single_word": false,
"lstrip": false,
"rstrip": false,
"normalized": false,
"special": true
},
{
"id": 295,
"content": "<ALTERNATIVE>",
"single_word": false,
"lstrip": false,
"rstrip": false,
"normalized": false,
"special": true
},
{
"id": 296,
"content": "<CDR3_REGION>",
"single_word": false,
"lstrip": false,
"rstrip": false,
"normalized": false,
"special": true
},
{
"id": 297,
"content": "<GENERAL_CHAIN>",
"single_word": false,
"lstrip": false,
"rstrip": false,
"normalized": false,
"special": true
},
{
"id": 298,
"content": "<SUBMOLECULAR_ENTITY>",
"single_word": false,
"lstrip": false,
"rstrip": false,
"normalized": false,
"special": true
},
{
"id": 299,
"content": "<MUTATED>",
"single_word": false,
"lstrip": false,
"rstrip": false,
"normalized": false,
"special": true
}
],
"normalizer": null,
Expand Down Expand Up @@ -2963,6 +3017,12 @@
"<CORRUPTED_AREA_END>": 291,
"<MOLECULAR_ENTITY_MUTATED_PROTEIN_CHAIN>": 292,
"<MOLECULAR_ENTITY_PROTEIN_CHAIN>": 293,
"<COMPLEX_ENTITY>": 294,
"<ALTERNATIVE>": 295,
"<CDR3_REGION>": 296,
"<GENERAL_CHAIN>": 297,
"<SUBMOLECULAR_ENTITY>": 298,
"<MUTATED>": 299,
"[CL:0000499]": 3522,
"[CL:2000060]": 3523,
"[CL:0000235]": 3524,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2648,6 +2648,60 @@
"rstrip": false,
"normalized": false,
"special": true
},
{
"id": 294,
"content": "<COMPLEX_ENTITY>",
"single_word": false,
"lstrip": false,
"rstrip": false,
"normalized": false,
"special": true
},
{
"id": 295,
"content": "<ALTERNATIVE>",
"single_word": false,
"lstrip": false,
"rstrip": false,
"normalized": false,
"special": true
},
{
"id": 296,
"content": "<CDR3_REGION>",
"single_word": false,
"lstrip": false,
"rstrip": false,
"normalized": false,
"special": true
},
{
"id": 297,
"content": "<GENERAL_CHAIN>",
"single_word": false,
"lstrip": false,
"rstrip": false,
"normalized": false,
"special": true
},
{
"id": 298,
"content": "<SUBMOLECULAR_ENTITY>",
"single_word": false,
"lstrip": false,
"rstrip": false,
"normalized": false,
"special": true
},
{
"id": 299,
"content": "<MUTATED>",
"single_word": false,
"lstrip": false,
"rstrip": false,
"normalized": false,
"special": true
}
],
"normalizer": null,
Expand Down Expand Up @@ -2963,6 +3017,12 @@
"<CORRUPTED_AREA_END>": 291,
"<MOLECULAR_ENTITY_MUTATED_PROTEIN_CHAIN>": 292,
"<MOLECULAR_ENTITY_PROTEIN_CHAIN>": 293,
"<COMPLEX_ENTITY>": 294,
"<ALTERNATIVE>": 295,
"<CDR3_REGION>": 296,
"<GENERAL_CHAIN>": 297,
"<SUBMOLECULAR_ENTITY>": 298,
"<MUTATED>": 299,
"[100130093]": 5000,
"[100133445]": 5001,
"[100286793]": 5002,
Expand Down
Loading

0 comments on commit a0a6c02

Please sign in to comment.