Scalars support #132

Merged 6 commits on Jul 26, 2024
207 changes: 207 additions & 0 deletions fusedrug/data/tokenizer/injectortokenizer/injector_tokenizer.py
@@ -0,0 +1,207 @@
from fusedrug.data.tokenizer.modulartokenizer.modular_tokenizer import ModularTokenizer
from typing import Optional, List, Tuple, Dict
from tokenizers import Encoding
import torch
import re
from fuse.utils import NDict


class InjectorTokenizer(ModularTokenizer):
Collaborator Author:
originally I wanted to use this as a drop-in replacement for ModularTokenizer, but in order to avoid code duplication I ended up only storing static methods here.
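As a reading aid (not part of the PR), here is a minimal sketch of the per-token input vector layout described in the class docstring below; the helper name, input_dim, and the position of the scalar inside the content part are assumptions:

import torch

def make_scalar_token_vector(value: float, input_dim: int) -> torch.Tensor:
    # hypothetical helper, for illustration only
    # header (4 floats): [is_sentinel_or_mask, is_standard_vocab_token, is_scalar, is_full_injected_vector]
    header = torch.tensor([0.0, 0.0, 1.0, 0.0])
    # content: the remaining input_dim - 4 floats; placing the scalar in the first slot is an assumption
    content = torch.zeros(input_dim - 4)
    content[0] = float(value)
    return torch.cat([header, content])  # shape: [input_dim]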

"""
InjectorTokenizer builds on top of ModularTokenizer.

Its purpose is to extend beyond "standard" integer tokens as the input for a model.
Instead, it provides control over the *vectors* that are used as input for the model.

Example use cases:
1. Providing scalars (floating point) as inputs
2. Providing vectors of embeddings - for example of a protein embedding

Each input "token" becomes a tensor of a defined size, and is built of:
1. Header
made of 4 floats
[
0.0 or 1.0 #is this a sentinel/mask or not
0.0 or 1.0 #is this a standard vocabulary token
0.0 or 1.0 #is this a scalar
0.0 or 1.0 #is this a full injected vector (e.g. an embedding)
]
2. Content
the rest of each input vector is made of input_dim-4 float elements.


Note - in the "standard vocabulary token" case - we support providing an external embedding layer (like in vanilla T5),
as it's part of the trained weights.

"""

@staticmethod
def build_placeholder_meta_tokenization(
*,
sequence: str,
sample_dict: Optional[NDict] = None,
) -> Tuple[str, List[str]]:
"""
In order to avoid modifying and rewriting the logic in the modular tokenizer (especially padding and the max-length limitation of certain sub-parts),
we insert placeholders to make sure that the total size is known/fixed and respects the meta instructions given to the modular tokenizer

Returns: a tuple with 2 elements
(
a single string with the full query containing placeholder tokens for FLOAT and VECTOR meta tokenizer parts,
a list of [meta-tokenizer name, data, meta-tokenizer name, data, meta-tokenizer name, data, ...]
)
"""
hints_and_subseq = re.split("<@TOKENIZER-TYPE=([^>]*)>", sequence)[
1:
] # the first element is blank - removing it
assert (
len(hints_and_subseq) > 0 and len(hints_and_subseq) % 2 == 0
), f"Error: expecting leading modular tokenizer hints followed by a sequence to tokenize, got {sequence}"

with_placeholders = []

for tokenizer_type, subseq in zip(
hints_and_subseq[::2], hints_and_subseq[1::2]
):
if tokenizer_type.startswith("SCALARS_"):
with_placeholders.append(
"<@TOKENIZER-TYPE=AA>"
) # won't use AA tokens, just an arbitrary one to be able to use a token like <SCALAR>

if (
tokenizer_type == "SCALARS_LITERALS"
): # note: masking is only supported in literals (not in "from dict")
values = subseq.split(",")
Collaborator:
So we should write "," and not "?

Collaborator Author:
Yes, the scalars tokenizer requires that you split them with ','.
If you have an alternative you prefer, do suggest it.
I will add a description of the expected format to the docstrings of the injector files.

Collaborator Author:

Added docstrings with a format description for both injector_tokenizer.py and injector_tokenizer_op.

Also renamed InjectorTokenizer to InjectorTokenizerHelpers and stopped inheriting from ModularTokenizer there, since that was misleading; it's just 2 static helper methods.

Collaborator Author:

This is part of the docstrings I've added:

applies an injector tokenizer

    the injector tokenizer builds on top of the modular tokenizer.
    its purpose is to build inputs_emb for the model (instead of input_ids)
        this allows supporting more advanced inputs beyond token ids, such as:
        * scalar inputs
        * embedding vectors within a single input

    supported syntax/format:

    text following <@TOKENIZER-TYPE=SCALARS_LITERALS> supports the following format:
    ',' separated float values and/or <MASK> tokens -
        for example: "2.7,3.99,-12.9" or "<MASK><MASK>" or "2.19,<MASK>,3.19,<MASK>"

    text following <@TOKENIZER-TYPE=SCALARS_FROM_DICT> is expected to be a key into the sample NDict
        for example: "blah.boo.banana" or "data.input.encoder_input"
        note: in SCALARS_FROM_DICT you can't describe masked scalars (outputs); you can only describe inputs

    example usage:

    encoder_input:
    <@TOKENIZER-TYPE=AA><MOLECULAR_WEIGHT_IN_SOME_UNIT><@TOKENIZER-TYPE=SCALARS_LITERALS>0.3<@TOKENIZER-TYPE=AA><BINDING_AFFINITY_NANOMOLAR><@TOKENIZER-TYPE=SCALARS_LITERALS><MASK><@TOKENIZER-TYPE=AA><SEQUENCE_NATURAL_START>ISGGDAIYSSTGRCSLGFNVRSGSTYYFLTAGICTDGATTWWANSARTTVLGTTSGSSFPNNDYGIVRYTNTTIPKDGTVGGQDITSAANATVGMAVTRRGSTTGTISGSVTALNATVNYGGGDVVYGMIRTNVCAEPGDSGGPLYSGTRAIGLTSGGSGNCSSGGTTFFQPVTEALVAYGVSVY<SEQUENCE_NATURAL_END>
    labels:
    <@TOKENIZER-TYPE=AA><MOLECULAR_WEIGHT_IN_SOME_UNIT><@TOKENIZER-TYPE=SCALARS_LITERALS>0.3<@TOKENIZER-TYPE=AA><BINDING_AFFINITY_NANOMOLAR><@TOKENIZER-TYPE=SCALARS_LITERALS>12.4<@TOKENIZER-TYPE=AA><SEQUENCE_NATURAL_START>ISGGDAIYSSTGRCSLGFNVRSGSTYYFLTAGICTDGATTWWANSARTTVLGTTSGSSFPNNDYGIVRYTNTTIPKDGTVGGQDITSAANATVGMAVTRRGSTTGTISGSVTALNATVNYGGGDVVYGMIRTNVCAEPGDSGGPLYSGTRAIGLTSGGSGNCSSGGTTFFQPVTEALVAYGVSVY<SEQUENCE_NATURAL_END>
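For illustration only (not taken from the PR), such a query string can be assembled programmatically; the shortened amino-acid sequence below is a placeholder:

mw = 0.3
aa_seq = "ISGGDAIYSSTGRCSLGFNVRSGSTYYFLTAGICTDGATT"  # shortened placeholder sequence
encoder_input = (
    "<@TOKENIZER-TYPE=AA><MOLECULAR_WEIGHT_IN_SOME_UNIT>"
    f"<@TOKENIZER-TYPE=SCALARS_LITERALS>{mw}"
    "<@TOKENIZER-TYPE=AA><BINDING_AFFINITY_NANOMOLAR>"
    "<@TOKENIZER-TYPE=SCALARS_LITERALS><MASK>"
    f"<@TOKENIZER-TYPE=AA><SEQUENCE_NATURAL_START>{aa_seq}<SEQUENCE_NATURAL_END>"
)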

# seq = "<SCALAR>" * len(values)
seq = "".join(
[
"<MASKED_SCALAR>" if x == "<MASK>" else "<SCALAR>"
for x in values
]
)
elif tokenizer_type == "SCALARS_FROM_DICT":
if sample_dict is None:
raise Exception(
"SCALARS_FROM_DICT used but the provided sample_dict is None"
)
values = sample_dict[subseq]
assert len(values.shape) == 1
seq = "<SCALAR>" * len(values)
else:
raise Exception(f"tokenizer_type={tokenizer_type} is not supported")

# elif tokenizer_type == "SCALARS_MASKED":
Collaborator:
remove

# values = subseq.split(",")
# assert all([x=='<MASK>' for x in values]) #only <MASK> is currently supported
# seq = "<MASKED_SCALAR>" * len(values)

with_placeholders.append(seq)

elif tokenizer_type.startswith("VECTORS_"):
raise Exception("VECTOR_* are not supported yet")
else:
with_placeholders.append("<@TOKENIZER-TYPE=" + tokenizer_type + ">")
Collaborator:

You might mistakenly drop the max length per element here.

Collaborator Author:

I don't think so:

sequence = "<@TOKENIZER-TYPE=AA><BLAH><BLAH2>QKPGQAPRLLIYG<@TOKENIZER-TYPE=AA@MAX-LEN=122><BLAH3>SGSDFSDFSFD"
hints_and_subseq = re.split("<@TOKENIZER-TYPE=([^>]*)>", sequence)[1:]
# hints_and_subseq == ['AA', '<BLAH><BLAH2>QKPGQAPRLLIYG', 'AA@MAX-LEN=122', '<BLAH3>SGSDFSDFSFD']

so the 'AA@MAX-LEN=122' hint, including its max length, is preserved and passed through unchanged.

Tell me if you still think I'm missing something here.

with_placeholders.append(subseq)

return "".join(with_placeholders), hints_and_subseq

@staticmethod
def prepare_info_for_model_step(
*,
per_meta_tokenizer_data: List[str],
per_meta_encoding_including_placeholders: List[Encoding],
sample_dict: Optional[NDict] = None,
) -> Dict:
"""
since we:
1. need to use the model embedding layer (allowing gradients to flow, if needed)
2. prefer not to use the model during the data pipeline

this function prepares everything so that during the train/val/test_step we'll be able to do what's needed before the forward pass

Args:
per_meta_tokenizer_data: a list of [meta-tokenizer name, data, meta-tokenizer name, data, meta-tokenizer name, data, ...]
per_meta_encoding_including_placeholders: a list of Encoding elements. This is used to extract the final number of tokens per tokenizer part (after all of the padding and cropping logic has already been applied)
sample_dict: a fuse sample_dict - optional.
needed only if the meta tokenizer instruction uses a syntax of lookup from the dictionary


"""
scalars_indices = []
scalars_values = []
scalars_masked_indices = []
prev_index_end = -1

for tokenizer_name, curr_str_data, curr_placeholder_encoding in zip(
per_meta_tokenizer_data[::2],
per_meta_tokenizer_data[1::2],
per_meta_encoding_including_placeholders,
):
if tokenizer_name.startswith("SCALARS_"):
if "SCALARS_LITERALS" == tokenizer_name:
curr_str_data = curr_str_data.strip().split(",")
if len(curr_str_data) != len(curr_placeholder_encoding.ids):
raise Exception(
f"should match expected length. Found length {len(curr_str_data)} but placeholders length was {len(curr_placeholder_encoding.ids)}"
)

curr_indices = []
curr_data = []

for i, val in enumerate(curr_str_data):
if val != "<MASK>":
curr_indices.append(i + prev_index_end + 1)
curr_data.append(float(val))
else:
scalars_masked_indices.append(i + prev_index_end + 1)
Collaborator:

So this is a running index over the scalars, i.e. an index that aligns them to the encoder_input?

Collaborator Author:

Yes, this collects all of the indices (at the level of final tokens) of masked scalars across the entire sequence.
It is expected to be empty for labels, and possibly non-empty for encoder_input.
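For context, a rough sketch (not part of this PR; the function name, the projection layer, and the tensor shapes are assumptions) of how these outputs could be consumed in the model step to build inputs_embeds:

import torch

def build_inputs_embeds(input_ids, embedding_layer, scalar_proj, scalars_indices, scalars_values):
    # standard token embeddings; gradients can still flow through embedding_layer
    inputs_embeds = embedding_layer(input_ids)  # [seq_len, dim]
    if isinstance(scalars_indices, torch.Tensor) and scalars_indices.numel() > 0:
        # project each scalar to the model dimension and overwrite the <SCALAR> placeholder positions
        projected = scalar_proj(scalars_values.unsqueeze(-1))  # [num_scalars, dim]
        inputs_embeds[scalars_indices] = projected
    return inputs_embeds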


if len(curr_indices) > 0:
curr_indices = torch.tensor(curr_indices, dtype=torch.int64)
curr_data = torch.tensor(curr_data, dtype=torch.float32)

scalars_indices.append(curr_indices)
scalars_values.append(curr_data)

assert len(curr_data.shape) == 1
elif "SCALARS_FROM_DICT" == tokenizer_name:
if sample_dict is None:
raise Exception(
"SCALARS_FROM_DICT used but the provided sample_dict is None"
)
curr_data = sample_dict[curr_str_data]
assert len(curr_data.shape) == 1
curr_indices = torch.arange(
prev_index_end + 1, prev_index_end + 1 + curr_data.shape[0]
)

scalars_indices.append(curr_indices)
scalars_values.append(curr_data)

prev_index_end += curr_data.shape[0]

else:
raise Exception(
"Only supported SCALARS_* tokenizers are SCALARS_LITERALS and SCALARS_FROM_DICT"
)

elif tokenizer_name.startswith("VECTORS_"):
raise NotImplementedError
else:
prev_index_end += len(curr_placeholder_encoding.ids)

if len(scalars_indices) > 0:
scalars_indices = torch.concat(scalars_indices)
scalars_values = torch.concat(scalars_values)

if len(scalars_masked_indices) > 0:
scalars_masked_indices = torch.tensor(
scalars_masked_indices, dtype=torch.int64
)
else:
scalars_masked_indices = None

return {
"scalars_indices": scalars_indices,
"scalars_values": scalars_values,
"scalars_masked_indices": scalars_masked_indices,
}
22 changes: 19 additions & 3 deletions fusedrug/data/tokenizer/modulartokenizer/modular_tokenizer.py
@@ -1006,7 +1006,13 @@ def encode_list(
return_overflow_info: Optional[bool] = False,
on_unknown: Optional[str] = "warn",
verbose: int = 1,
) -> Union[Encoding, Tuple[Encoding, str]]:
also_return_split: bool = False,
) -> Union[
Encoding,
Tuple[Encoding, str],
Tuple[Encoding, List[Encoding]],
Tuple[Encoding, str, List[Encoding]],
]:
"""_summary_

Args:
@@ -1025,6 +1031,7 @@
on_unknown: (Optional[str], optional): What happens if unknown tokens (i.e. ones mapped to <UNK>) are encountered: 'raise' or 'warn'
verbose (Optional[int], optional): verbosity level. 0: no notification, 1: warning notification, 2: warning with partial data, 3: warning
with full data. Defaults to 1.
also_return_split: defaults to False. If set to True, the return value will also contain a list with one Encoding element per meta-tokenizer instruction
Collaborator:

Should it be set to True if we want scalar support, or is it just for debugging?
If it is used for scalars, can we simply infer it from typed_input_list?

Collaborator Author:

You don't call this directly; injector_tokenizer_op does it automatically for you. It's not just for debugging.
We can't infer it from typed_input_list because we don't know how many tokens there will be per tokenizer part (it's not always 1:1 - there are things like SMILES, and things like cropping/padding).

If we only get the final merged encoding we can't understand:

  1. which tokens we should replace with <MASKED_SCALAR>
  2. where the scalar tokens are

The only way to do that externally is to effectively redo the entire logic of the modular tokenizer, including the actual tokenization, padding and cropping, which is both code duplication and slower.
That's why I preferred to allow returning this already-calculated "internal split".

If this isn't completely clear yet, let's talk.
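For reference, a minimal usage sketch (assuming a loaded ModularTokenizer instance named tokenizer and a typed query string named sequence; not taken from the PR's tests):

# only the split requested: a 2-tuple is returned
merged_encoding, per_part_encodings = tokenizer.encode(sequence, also_return_split=True)

# both extras requested: a 3-tuple, in this order
merged_encoding, overflow_info, per_part_encodings = tokenizer.encode(
    sequence,
    return_overflow_info=True,
    also_return_split=True,
)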

Returns:
Encoding: _description_
"""
@@ -1150,9 +1157,15 @@
f"Unexpected on_unknown value {on_unknown}. Should be 'warn' or 'raise'"
)

if (not return_overflow_info) and (not also_return_split):
return merged_encoding
ans = [merged_encoding]
if return_overflow_info:
return merged_encoding, overflow_info
return merged_encoding
ans += [overflow_info]
if also_return_split:
ans += [encoded_list]

return tuple(ans)

def decode(self, ids: Iterable, skip_special_tokens: Optional[bool] = False) -> str:
"""Receives a list of IDs and returns a string of tokens
@@ -1190,6 +1203,7 @@ def encode(
return_overflow_info: Optional[bool] = False,
on_unknown: Optional[str] = "warn",
verbose: Optional[int] = 1,
also_return_split: bool = False,
) -> Encoding:
# (self, sequence, pair=None, is_pretokenized=False, add_special_tokens=True)
"""Receives a user-supplied string that contains, in addition to the text that is to be tokenized, special delimiters signifying the type
@@ -1210,6 +1224,7 @@
on_unknown: (Optional[str], optional): What happens if unknown tokens (i.e. ones mapped to <UNK>) are encountered: 'raise' or 'warn'
verbose (int, optional): verbosity level. 0: no notification, 1: warning notification, 2: warning with partial data, 3: warning
with full data. Defaults to 1.
also_return_split: also return the per-meta-instruction encoded parts as a list of Encoding elements
Returns:
Encoding: _description_
str: _description_ information on overflow, if return_overflow_info=True
@@ -1251,6 +1266,7 @@
return_overflow_info=return_overflow_info,
on_unknown=on_unknown,
verbose=verbose,
also_return_split=also_return_split,
)

def get_tokenizer_types(self) -> List:
@@ -2747,6 +2747,42 @@
"rstrip": false,
"normalized": false,
"special": true
},
{
"id": 305,
"content": "<SCALAR>",
"single_word": false,
"lstrip": false,
"rstrip": false,
"normalized": false,
"special": true
},
{
"id": 306,
"content": "<VECTOR>",
"single_word": false,
"lstrip": false,
"rstrip": false,
"normalized": false,
"special": true
},
{
"id": 307,
"content": "<MASKED_SCALAR>",
"single_word": false,
"lstrip": false,
"rstrip": false,
"normalized": false,
"special": true
},
{
"id": 308,
"content": "<MASKED_VECTOR>",
"single_word": false,
"lstrip": false,
"rstrip": false,
"normalized": false,
"special": true
}
],
"normalizer": null,
@@ -3067,6 +3103,10 @@
"<MOLECULAR_ENTITY_TCR_DELTA_VAR>": 302,
"<MOLECULAR_ENTITY_TCR_GAMMA_CDR3>": 303,
"<MOLECULAR_ENTITY_TCR_GAMMA_VAR>": 304,
"<SCALAR>": 305,
"<VECTOR>": 306,
"<MASKED_SCALAR>": 307,
"<MASKED_VECTOR>": 308,
"#": 527,
"%": 528,
"(": 529,
@@ -2747,6 +2747,42 @@
"rstrip": false,
"normalized": false,
"special": true
},
{
"id": 305,
"content": "<SCALAR>",
"single_word": false,
"lstrip": false,
"rstrip": false,
"normalized": false,
"special": true
},
{
"id": 306,
"content": "<VECTOR>",
"single_word": false,
"lstrip": false,
"rstrip": false,
"normalized": false,
"special": true
},
{
"id": 307,
"content": "<MASKED_SCALAR>",
"single_word": false,
"lstrip": false,
"rstrip": false,
"normalized": false,
"special": true
},
{
"id": 308,
"content": "<MASKED_VECTOR>",
"single_word": false,
"lstrip": false,
"rstrip": false,
"normalized": false,
"special": true
}
],
"normalizer": null,
@@ -3073,6 +3109,10 @@
"<MOLECULAR_ENTITY_TCR_DELTA_VAR>": 302,
"<MOLECULAR_ENTITY_TCR_GAMMA_CDR3>": 303,
"<MOLECULAR_ENTITY_TCR_GAMMA_VAR>": 304,
"<SCALAR>": 305,
"<VECTOR>": 306,
"<MASKED_SCALAR>": 307,
"<MASKED_VECTOR>": 308,
"[CL:0000499]": 3522,
"[CL:2000060]": 3523,
"[CL:0000235]": 3524,