diff --git a/netam/dxsm.py b/netam/dxsm.py index 4ac7a7a4..ebfbf7c0 100644 --- a/netam/dxsm.py +++ b/netam/dxsm.py @@ -32,7 +32,7 @@ token_mask_of_aa_idxs, MAX_AA_TOKEN_IDX, RESERVED_TOKEN_REGEX, - AA_AMBIG_IDX, + AA_PADDING_TOKEN, ) @@ -134,7 +134,7 @@ def of_seriess( # We have sequences of varying length, so we start with all tensors set # to the ambiguous amino acid, and then will fill in the actual values # below. - aa_parents_idxss = torch.full((pcp_count, max_aa_seq_len), AA_AMBIG_IDX) + aa_parents_idxss = torch.full((pcp_count, max_aa_seq_len), AA_PADDING_TOKEN) aa_children_idxss = aa_parents_idxss.clone() aa_subs_indicators = torch.zeros((pcp_count, max_aa_seq_len)) diff --git a/netam/models.py b/netam/models.py index 1785661c..ee43e019 100644 --- a/netam/models.py +++ b/netam/models.py @@ -23,6 +23,7 @@ aa_idx_tensor_of_str_ambig, PositionalEncoding, split_heavy_light_model_outputs, + AA_PADDING_TOKEN, ) from typing import Tuple @@ -795,6 +796,165 @@ def predict(self, representation: Tensor): return wiggle(super().predict(representation), beta) +def reverse_padded_tensors(padded_tensors, padding_mask, padding_value, reversed_dim=1): + """Reverse the valid values in provided padded_tensors along the specified + dimension, keeping padding in the same place. For example, if the input is + left-aligned amino acid sequences and masks, move the padding to the right of the + reversed sequence. Equivalent to right-aligning the sequences then reversing them. A + sequence `123456XXXXX` becomes `654321XXXXX`. + + The original padding mask remains valid for the returned tensor. + + Args: + padded_tensors: (B, L) tensor of amino acid indices, or a (B, L, ...) tensor of per-site values such as encoder representations. + padding_mask: (B, L) tensor of masks, with True indicating valid values, and False indicating padding values. + padding_value: The value to fill the returned tensor with where padding_mask is False. + reversed_dim: The dimension along which to reverse the tensor.
When input is a batch of sequences to be reversed, the default value of 1 is the correct choice. + Returns: + The reversed tensor, with the same shape as padded_tensors, and with padding still specified by padding_mask. + """ + reversed_indices = torch.full_like(padded_tensors, padding_value) + reversed_indices[padding_mask] = padded_tensors.flip(reversed_dim)[ + padding_mask.flip(reversed_dim) + ] + return reversed_indices + + +class BidirectionalTransformerBinarySelectionModel(AbstractBinarySelectionModel): + def __init__( + self, + nhead: int, + d_model_per_head: int, + dim_feedforward: int, + layer_count: int, + dropout_prob: float = 0.5, + output_dim: int = 1, + known_token_count: int = MAX_AA_TOKEN_IDX + 1, + ): + super().__init__() + self.known_token_count = known_token_count + self.d_model_per_head = d_model_per_head + self.d_model = d_model_per_head * nhead + self.nhead = nhead + self.dim_feedforward = dim_feedforward + # Forward direction components + self.forward_pos_encoder = PositionalEncoding(self.d_model, dropout_prob) + self.forward_amino_acid_embedding = nn.Embedding( + self.known_token_count, self.d_model + ) + self.forward_encoder_layer = nn.TransformerEncoderLayer( + d_model=self.d_model, + nhead=nhead, + dim_feedforward=dim_feedforward, + batch_first=True, + ) + self.forward_encoder = nn.TransformerEncoder( + self.forward_encoder_layer, layer_count + ) + + # Reverse direction components + self.reverse_pos_encoder = PositionalEncoding(self.d_model, dropout_prob) + self.reverse_amino_acid_embedding = nn.Embedding( + self.known_token_count, self.d_model + ) + self.reverse_encoder_layer = nn.TransformerEncoderLayer( + d_model=self.d_model, + nhead=nhead, + dim_feedforward=dim_feedforward, + batch_first=True, + ) + self.reverse_encoder = nn.TransformerEncoder( + self.reverse_encoder_layer, layer_count + ) + + # Output layers + self.combine_features = nn.Linear(2 * self.d_model, self.d_model) + self.output = nn.Linear(self.d_model, output_dim) + + 
self.init_weights() + + def init_weights(self) -> None: + initrange = 0.1 + self.combine_features.bias.data.zero_() + self.combine_features.weight.data.uniform_(-initrange, initrange) + self.output.bias.data.zero_() + self.output.weight.data.uniform_(-initrange, initrange) + + def single_direction_represent_sequence( + self, + indices: Tensor, + mask: Tensor, + embedding: nn.Embedding, + pos_encoder: PositionalEncoding, + encoder: nn.TransformerEncoder, + ) -> Tensor: + """Process sequence through one direction of the model.""" + embedded = embedding(indices) * math.sqrt(self.d_model) + embedded = pos_encoder(embedded.permute(1, 0, 2)).permute(1, 0, 2) + return encoder(embedded, src_key_padding_mask=~mask) + + def represent(self, amino_acid_indices: Tensor, mask: Tensor) -> Tensor: + # The padded reversal used below is okay as long as there are no masked + # ambiguities in the interior of the sequence. It should also work for paired seqs. + + # Forward direction - normal processing + forward_repr = self.single_direction_represent_sequence( + amino_acid_indices, + mask, + self.forward_amino_acid_embedding, + self.forward_pos_encoder, + self.forward_encoder, + ) + + # Reverse direction - reverse the valid values; padding stays in place, so the same mask is reused + reversed_indices = reverse_padded_tensors( + amino_acid_indices, mask, AA_PADDING_TOKEN + ) + + reverse_repr = self.single_direction_represent_sequence( + reversed_indices, + mask, + self.reverse_amino_acid_embedding, + self.reverse_pos_encoder, + self.reverse_encoder, + ) + + # Un-reverse so each site lines up with the forward representation + aligned_reverse_repr = reverse_padded_tensors(reverse_repr, mask, 0.0) + + # Combine features + combined = torch.cat([forward_repr, aligned_reverse_repr], dim=-1) + return self.combine_features(combined) + + def predict(self, representation: Tensor) -> Tensor: + # Output layer + return self.output(representation).squeeze(-1) + + def forward(self, amino_acid_indices: Tensor, mask: Tensor) -> Tensor: + return 
self.predict(self.represent(amino_acid_indices, mask)) + + @property + def hyperparameters(self): + return { + "nhead": self.nhead, + "d_model_per_head": self.d_model_per_head, + "dim_feedforward": self.dim_feedforward, + "layer_count": self.forward_encoder.num_layers, + "dropout_prob": self.forward_pos_encoder.dropout.p, + "output_dim": self.output.out_features, + "known_token_count": self.known_token_count, + } + + +class BidirectionalTransformerBinarySelectionModelWiggleAct( + BidirectionalTransformerBinarySelectionModel +): + """Here the beta parameter is fixed at 0.3.""" + + def predict(self, representation: Tensor): + return wiggle(super().predict(representation), 0.3) + + class SingleValueBinarySelectionModel(AbstractBinarySelectionModel): """A one parameter selection model as a baseline.""" diff --git a/netam/sequences.py b/netam/sequences.py index 01d99179..9263d68b 100644 --- a/netam/sequences.py +++ b/netam/sequences.py @@ -22,6 +22,9 @@ NT_STR_SORTED = "".join(BASES) BASES_AND_N_TO_INDEX = {base: idx for idx, base in enumerate(NT_STR_SORTED + "N")} AA_AMBIG_IDX = len(AA_STR_SORTED) +# Used for padding amino acid sequences to the same length. Differentiated by +# name in case we add a padding token other than AA_AMBIG_IDX later. 
+AA_PADDING_TOKEN = AA_AMBIG_IDX CODONS = ["".join(codon_list) for codon_list in itertools.product(BASES, repeat=3)] STOP_CODONS = ["TAA", "TAG", "TGA"] diff --git a/tests/test_netam.py b/tests/test_netam.py index c2fe9088..c0faed11 100644 --- a/tests/test_netam.py +++ b/tests/test_netam.py @@ -8,7 +8,12 @@ import netam.framework as framework from netam.common import BIG from netam.framework import SHMoofDataset, SHMBurrito, RSSHMBurrito -from netam.models import SHMoofModel, RSSHMoofModel, IndepRSCNNModel +from netam.models import ( + SHMoofModel, + RSSHMoofModel, + IndepRSCNNModel, + reverse_padded_tensors, +) @pytest.fixture @@ -114,3 +119,32 @@ def test_standardize_model_rates(mini_rsburrito): mini_rsburrito.standardize_model_rates() vrc01_rate_14 = mini_rsburrito.vrc01_site_14_model_rate() assert np.isclose(vrc01_rate_14, 1.0) + + +def test_reverse_padded_tensors(): + # Check the reversal against a hand-computed expectation, and that applying + # the function twice gives the original input back. + test_tensor = torch.tensor( + [ + [1, 2, 3, 4, 0, 0], + [1, 2, 0, 0, 0, 0], + [1, 0, 0, 0, 0, 0], + [1, 2, 3, 4, 5, 0], + [1, 2, 3, 4, 5, 6], + [1, 2, 0, 0, 0, 0], + ] + ) + true_reversed = torch.tensor( + [ + [4, 3, 2, 1, 0, 0], + [2, 1, 0, 0, 0, 0], + [1, 0, 0, 0, 0, 0], + [5, 4, 3, 2, 1, 0], + [6, 5, 4, 3, 2, 1], + [2, 1, 0, 0, 0, 0], + ] + ) + mask = test_tensor > 0 + reversed_tensor = reverse_padded_tensors(test_tensor, mask, 0) + assert torch.equal(true_reversed, reversed_tensor) + assert torch.equal(test_tensor, reverse_padded_tensors(reversed_tensor, mask, 0))