Commit
Scalars support (#132)
Mainly related to adding support for scalar inputs/outputs to our mammal architecture.

---------

Co-authored-by: YoelShoshan <[email protected]>
YoelShoshan and YoelShoshan authored Jul 26, 2024
1 parent 818cb8c commit ccf7505
Showing 16 changed files with 820 additions and 5 deletions.
207 changes: 207 additions & 0 deletions fusedrug/data/tokenizer/injectortokenizer/injector_tokenizer.py
@@ -0,0 +1,207 @@
from typing import Optional, List, Tuple, Dict
from tokenizers import Encoding
import torch
import re
from fuse.utils import NDict


class InjectorTokenizerHelpers:
"""
InjectorTokenizer builds on top of ModularTokenizer.
!!!!
Note: this file contains only a few utility (static) functions for InjectorTokenizerOp.
As a user, you are not expected to use InjectorTokenizer directly; instead, use fusedrug.data.tokenizer.ops.injector_tokenizer_ops.InjectorTokenizerOp
!!!!
Applies an injector tokenizer.
The injector tokenizer builds on top of the modular tokenizer.
Its purpose is to build inputs_emb for the model (instead of input_ids).
This allows supporting more advanced inputs beyond token ids, for example:
* scalar inputs
* embedding vectors within a single input
Supported syntax/format:
text following <@TOKENIZER-TYPE=SCALARS_LITERALS> supports the following format:
',' separated float values and/or <MASK> tokens,
for example: "2.7,3.99,-12.9" or "<MASK><MASK>" or "2.19,<MASK>,3.19,<MASK>"
text following <@TOKENIZER-TYPE=SCALARS_FROM_DICT> is expected to be a key into the sample NDict,
for example: "blah.boo.banana" or "data.input.encoder_input"
note: with SCALARS_FROM_DICT you can't describe masked scalars (outputs); you can only describe inputs
example usage:
encoder_input:
<@TOKENIZER-TYPE=AA><MOLECULAR_WEIGHT_IN_SOME_UNIT><@TOKENIZER-TYPE=SCALARS_LITERALS>0.3<@TOKENIZER-TYPE=AA><BINDING_AFFINITY_NANOMOLAR><@TOKENIZER-TYPE=SCALARS_LITERALS><MASK><@TOKENIZER-TYPE=AA><SEQUENCE_NATURAL_START>ISGGDAIYSSTGRCSLGFNVRSGSTYYFLTAGICTDGATTWWANSARTTVLGTTSGSSFPNNDYGIVRYTNTTIPKDGTVGGQDITSAANATVGMAVTRRGSTTGTISGSVTALNATVNYGGGDVVYGMIRTNVCAEPGDSGGPLYSGTRAIGLTSGGSGNCSSGGTTFFQPVTEALVAYGVSVY<SEQUENCE_NATURAL_END>
labels:
<@TOKENIZER-TYPE=AA><MOLECULAR_WEIGHT_IN_SOME_UNIT><@TOKENIZER-TYPE=SCALARS_LITERALS>0.3<@TOKENIZER-TYPE=AA><BINDING_AFFINITY_NANOMOLAR><@TOKENIZER-TYPE=SCALARS_LITERALS>12.4<@TOKENIZER-TYPE=AA><SEQUENCE_NATURAL_START>ISGGDAIYSSTGRCSLGFNVRSGSTYYFLTAGICTDGATTWWANSARTTVLGTTSGSSFPNNDYGIVRYTNTTIPKDGTVGGQDITSAANATVGMAVTRRGSTTGTISGSVTALNATVNYGGGDVVYGMIRTNVCAEPGDSGGPLYSGTRAIGLTSGGSGNCSSGGTTFFQPVTEALVAYGVSVY<SEQUENCE_NATURAL_END>
"""

@staticmethod
def build_placeholder_meta_tokenization(
*,
sequence: str,
sample_dict: Optional[NDict] = None,
) -> Tuple[str, List[str]]:
"""
To avoid modifying and rewriting the logic in the modular tokenizer, especially regarding padding and the max-length limits of certain sub-parts,
we insert placeholders to make sure that the total size is known/fixed and respects the meta instructions to the modular tokenizer.
Returns: a tuple with 2 elements
(
a single string with the full query containing placeholder tokens for FLOAT and VECTOR meta tokenizer parts,
a list of [meta-tokenizer name, data, meta-tokenizer name, data, meta-tokenizer name, data, ...]
)
"""
hints_and_subseq = re.split("<@TOKENIZER-TYPE=([^>]*)>", sequence)[
1:
] # the first element is blank - removing it
assert (
len(hints_and_subseq) > 0 and len(hints_and_subseq) % 2 == 0
), f"Error: expecting leading modular tokenizer hints followed by a sequence to tokenize, got {sequence}"

with_placeholders = []

for tokenizer_type, subseq in zip(
hints_and_subseq[::2], hints_and_subseq[1::2]
):
if tokenizer_type.startswith("SCALARS_"):
with_placeholders.append(
"<@TOKENIZER-TYPE=AA>"
) # won't use AA tokens, just an arbitrary one to be able to use a token like <SCALAR>

if (
tokenizer_type == "SCALARS_LITERALS"
): # note: masking is only supported in literals (not in "from dict")
values = subseq.split(",")
# seq = "<SCALAR>" * len(values)
seq = "".join(
[
"<MASKED_SCALAR>" if x == "<MASK>" else "<SCALAR>"
for x in values
]
)
elif tokenizer_type == "SCALARS_FROM_DICT":
if sample_dict is None:
raise Exception(
"SCALARS_FROM_DICT used but the provided sample_dict is None"
)
values = sample_dict[subseq]
assert len(values.shape) == 1
seq = "<SCALAR>" * len(values)
else:
raise Exception(f"tokenizer_type={tokenizer_type} is not supported")

with_placeholders.append(seq)

elif tokenizer_type.startswith("VECTORS_"):
raise Exception("VECTOR_* are not supported yet")
else:
with_placeholders.append("<@TOKENIZER-TYPE=" + tokenizer_type + ">")
with_placeholders.append(subseq)

return "".join(with_placeholders), hints_and_subseq

@staticmethod
def prepare_info_for_model_step(
*,
per_meta_tokenizer_data: List[str],
per_meta_encoding_including_placeholders: List[Encoding],
sample_dict: Optional[NDict] = None,
) -> Dict:
"""
Since we:
1. need to use the model embedding layer (allowing gradient flow if needed), and
2. prefer not to use the model during the data pipeline,
this function prepares everything so that during train/val/test_step we can do what's needed before the forward pass.
Args:
per_meta_tokenizer_data: a list of [meta-tokenizer name, data, meta-tokenizer name, data, meta-tokenizer name, data, ...]
per_meta_encoding_including_placeholders: a list of Encoding elements. This is used to extract the final number of tokens per sub-tokenizer (after all padding and cropping logic has already been applied)
sample_dict: a fuse sample_dict (optional).
Needed only if a meta-tokenizer instruction uses the dictionary-lookup syntax.
"""
scalars_indices = []
scalars_values = []
scalars_masked_indices = []
prev_index_end = -1

for tokenizer_name, curr_str_data, curr_placeholder_encoding in zip(
per_meta_tokenizer_data[::2],
per_meta_tokenizer_data[1::2],
per_meta_encoding_including_placeholders,
):
if tokenizer_name.startswith("SCALARS_"):
if "SCALARS_LITERALS" == tokenizer_name:
curr_str_data = curr_str_data.strip().split(",")
if len(curr_str_data) != len(curr_placeholder_encoding.ids):
raise Exception(
f"should match expected length. Found length {len(curr_str_data)} but placeholders length was {len(curr_placeholder_encoding.ids)}"
)

curr_indices = []
curr_data = []

for i, val in enumerate(curr_str_data):
if val != "<MASK>":
curr_indices.append(i + prev_index_end + 1)
curr_data.append(float(val))
else:
scalars_masked_indices.append(i + prev_index_end + 1)

if len(curr_indices) > 0:
curr_indices = torch.tensor(curr_indices, dtype=torch.int64)
curr_data = torch.tensor(curr_data, dtype=torch.float32)

scalars_indices.append(curr_indices)
scalars_values.append(curr_data)

assert len(curr_data.shape) == 1
elif "SCALARS_FROM_DICT" == tokenizer_name:
if sample_dict is None:
raise Exception(
"SCALARS_FROM_DICT used but the provided sample_dict is None"
)
curr_data = sample_dict[curr_str_data]
assert len(curr_data.shape) == 1
curr_indices = torch.arange(
prev_index_end + 1, prev_index_end + 1 + curr_data.shape[0]
)

scalars_indices.append(curr_indices)
scalars_values.append(curr_data)

prev_index_end += curr_data.shape[0]

else:
raise Exception(
"Only supported SCALARS_* tokenizers are SCALARS_LITERALS and SCALARS_FROM_DICT"
)

elif tokenizer_name.startswith("VECTORS_"):
raise NotImplementedError
else:
prev_index_end += len(curr_placeholder_encoding.ids)

if len(scalars_indices) > 0:
scalars_indices = torch.concat(scalars_indices)
scalars_values = torch.concat(scalars_values)

if len(scalars_masked_indices) > 0:
scalars_masked_indices = torch.tensor(
scalars_masked_indices, dtype=torch.int64
)
else:
scalars_masked_indices = None

return {
"scalars_indices": scalars_indices, # 1d - its length is the number of actual scalars (provided) found
"scalars_values": scalars_values, # 1d - values of provided scalars
"scalars_masked_indices": scalars_masked_indices, # 1d - indices of masked scalars
}
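As a quick illustration of the placeholder mechanism introduced above, here is a minimal usage sketch. The query string below is made up for illustration and is not part of this commit; only the syntax follows the class docstring.

# Hypothetical usage sketch of build_placeholder_meta_tokenization (illustrative only).
from fusedrug.data.tokenizer.injectortokenizer.injector_tokenizer import (
    InjectorTokenizerHelpers,
)

query = (
    "<@TOKENIZER-TYPE=AA><BINDING_AFFINITY_NANOMOLAR>"
    "<@TOKENIZER-TYPE=SCALARS_LITERALS>2.7,<MASK>"
    "<@TOKENIZER-TYPE=AA><SEQUENCE_NATURAL_START>ISGGDAIY<SEQUENCE_NATURAL_END>"
)

with_placeholders, parts = InjectorTokenizerHelpers.build_placeholder_meta_tokenization(
    sequence=query
)

# Each scalar literal becomes a <SCALAR> placeholder and each <MASK> becomes
# <MASKED_SCALAR>, so the modular tokenizer sees a token-only query of known size:
# "<@TOKENIZER-TYPE=AA><BINDING_AFFINITY_NANOMOLAR><@TOKENIZER-TYPE=AA><SCALAR><MASKED_SCALAR>"
# "<@TOKENIZER-TYPE=AA><SEQUENCE_NATURAL_START>ISGGDAIY<SEQUENCE_NATURAL_END>"
# `parts` is the flat [tokenizer-name, data, tokenizer-name, data, ...] list that
# prepare_info_for_model_step later consumes together with the per-part encodings
# (obtained via the new also_return_split flag shown below).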
22 changes: 19 additions & 3 deletions fusedrug/data/tokenizer/modulartokenizer/modular_tokenizer.py
@@ -1006,7 +1006,13 @@ def encode_list(
return_overflow_info: Optional[bool] = False,
on_unknown: Optional[str] = "warn",
verbose: int = 1,
) -> Union[Encoding, Tuple[Encoding, str]]:
also_return_split: bool = False,
) -> Union[
Encoding,
Tuple[Encoding, str],
Tuple[Encoding, List[Encoding]],
Tuple[Encoding, str, List[Encoding]],
]:
"""_summary_
Args:
@@ -1025,6 +1031,7 @@ def encode_list(
on_unknown: (Optional[str], optional): What happens if unknown tokens (i.e. ones mapped to <UNK>) are encountered: 'raise' or 'warn'
verbose (Optional[int], optional): verbosity level. 0: no notification, 1: warning notification, 2: warning with partial data, 3: warning
with full data. Defaults to 1.
also_return_split: defaults to False. If set to True, the return value will also contain a list with one Encoding element per meta-tokenizer instruction
Returns:
Encoding: _description_
"""
@@ -1150,9 +1157,15 @@ def encode_list(
f"Unexpected on_unknown value {on_unknown}. Should be 'warn' or 'raise'"
)

if (not return_overflow_info) and (not also_return_split):
return merged_encoding
ans = [merged_encoding]
if return_overflow_info:
return merged_encoding, overflow_info
return merged_encoding
ans += [overflow_info]
if also_return_split:
ans += [encoded_list]

return tuple(ans)

def decode(self, ids: Iterable, skip_special_tokens: Optional[bool] = False) -> str:
"""Receives a list of IDs and returns a string of tokens
@@ -1190,6 +1203,7 @@ def encode(
return_overflow_info: Optional[bool] = False,
on_unknown: Optional[str] = "warn",
verbose: Optional[int] = 1,
also_return_split: bool = False,
) -> Encoding:
# (self, sequence, pair=None, is_pretokenized=False, add_special_tokens=True)
"""Receives a user-supplied string that contains, in addition to the text that is to be tokenized, special delimiters signifying the type
@@ -1210,6 +1224,7 @@ def encode(
on_unknown: (Optional[str], optional): What happens if unknown tokens (i.e. ones mapped to <UNK>) are encountered: 'raise' or 'warn'
verbose (int, optional): verbosity level. 0: no notification, 1: warning notification, 2: warning with partial data, 3: warning
with full data. Defaults to 1.
also_return_split: also return the per-meta-instruction encoded parts as a list of Encoding elements
Returns:
Encoding: _description_
str: _description_ information on overflow, if return_overflow_info=True
@@ -1251,6 +1266,7 @@ def encode(
return_overflow_info=return_overflow_info,
on_unknown=on_unknown,
verbose=verbose,
also_return_split=also_return_split,
)

def get_tokenizer_types(self) -> List:
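The new also_return_split flag can be exercised roughly as follows. This is a minimal sketch, assuming encode forwards encode_list's return value unchanged; the tokenizer construction and the query string are assumptions and not part of this diff.

# Hypothetical sketch: `tokenizer` is assumed to be an already-constructed ModularTokenizer.
query = "<@TOKENIZER-TYPE=AA><SEQUENCE_NATURAL_START>ISGGDAIY<SEQUENCE_NATURAL_END>"

# With only also_return_split=True, a 2-tuple is returned:
merged_encoding, per_part_encodings = tokenizer.encode(query, also_return_split=True)

# With both flags set, the overflow-info string is returned in between (3-tuple):
merged_encoding, overflow_info, per_part_encodings = tokenizer.encode(
    query, return_overflow_info=True, also_return_split=True
)

# per_part_encodings holds one Encoding per meta-tokenizer instruction; it is what
# InjectorTokenizerHelpers.prepare_info_for_model_step expects as
# per_meta_encoding_including_placeholders.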
@@ -2747,6 +2747,42 @@
"rstrip": false,
"normalized": false,
"special": true
},
{
"id": 305,
"content": "<SCALAR>",
"single_word": false,
"lstrip": false,
"rstrip": false,
"normalized": false,
"special": true
},
{
"id": 306,
"content": "<VECTOR>",
"single_word": false,
"lstrip": false,
"rstrip": false,
"normalized": false,
"special": true
},
{
"id": 307,
"content": "<MASKED_SCALAR>",
"single_word": false,
"lstrip": false,
"rstrip": false,
"normalized": false,
"special": true
},
{
"id": 308,
"content": "<MASKED_VECTOR>",
"single_word": false,
"lstrip": false,
"rstrip": false,
"normalized": false,
"special": true
}
],
"normalizer": null,
@@ -3067,6 +3103,10 @@
"<MOLECULAR_ENTITY_TCR_DELTA_VAR>": 302,
"<MOLECULAR_ENTITY_TCR_GAMMA_CDR3>": 303,
"<MOLECULAR_ENTITY_TCR_GAMMA_VAR>": 304,
"<SCALAR>": 305,
"<VECTOR>": 306,
"<MASKED_SCALAR>": 307,
"<MASKED_VECTOR>": 308,
"#": 527,
"%": 528,
"(": 529,
@@ -2747,6 +2747,42 @@
"rstrip": false,
"normalized": false,
"special": true
},
{
"id": 305,
"content": "<SCALAR>",
"single_word": false,
"lstrip": false,
"rstrip": false,
"normalized": false,
"special": true
},
{
"id": 306,
"content": "<VECTOR>",
"single_word": false,
"lstrip": false,
"rstrip": false,
"normalized": false,
"special": true
},
{
"id": 307,
"content": "<MASKED_SCALAR>",
"single_word": false,
"lstrip": false,
"rstrip": false,
"normalized": false,
"special": true
},
{
"id": 308,
"content": "<MASKED_VECTOR>",
"single_word": false,
"lstrip": false,
"rstrip": false,
"normalized": false,
"special": true
}
],
"normalizer": null,
@@ -3073,6 +3109,10 @@
"<MOLECULAR_ENTITY_TCR_DELTA_VAR>": 302,
"<MOLECULAR_ENTITY_TCR_GAMMA_CDR3>": 303,
"<MOLECULAR_ENTITY_TCR_GAMMA_VAR>": 304,
"<SCALAR>": 305,
"<VECTOR>": 306,
"<MASKED_SCALAR>": 307,
"<MASKED_VECTOR>": 308,
"[CL:0000499]": 3522,
"[CL:2000060]": 3523,
"[CL:0000235]": 3524,
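The four special tokens registered in the tokenizer JSON files above can be sanity-checked after loading the updated vocabulary. A minimal sketch follows; the file path is illustrative and not taken from this commit.

# Hypothetical check that the new special tokens are present (path is illustrative).
from tokenizers import Tokenizer

tok = Tokenizer.from_file("path/to/updated/tokenizer.json")
for token, expected_id in [
    ("<SCALAR>", 305),
    ("<VECTOR>", 306),
    ("<MASKED_SCALAR>", 307),
    ("<MASKED_VECTOR>", 308),
]:
    assert tok.token_to_id(token) == expected_id, token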
