From 75b10673a3d16bfcd73c881c7722b0c6cfb99a7d Mon Sep 17 00:00:00 2001 From: Will Dumm Date: Fri, 14 Feb 2025 12:02:52 -0800 Subject: [PATCH 01/10] add bidirectional model by Claude --- netam/models.py | 114 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 114 insertions(+) diff --git a/netam/models.py b/netam/models.py index 1785661c..f33002e6 100644 --- a/netam/models.py +++ b/netam/models.py @@ -795,6 +795,120 @@ def predict(self, representation: Tensor): return wiggle(super().predict(representation), beta) +class BidirectionalTransformerBinarySelectionModel(AbstractBinarySelectionModel): + def __init__( + self, + nhead: int, + d_model_per_head: int, + dim_feedforward: int, + layer_count: int, + dropout_prob: float = 0.5, + output_dim: int = 1, + known_token_count: int = MAX_AA_TOKEN_IDX + 1, + ): + super().__init__() + self.known_token_count = known_token_count + self.d_model_per_head = d_model_per_head + self.d_model = d_model_per_head * nhead + self.nhead = nhead + self.dim_feedforward = dim_feedforward + # Forward direction components + self.forward_pos_encoder = PositionalEncoding(self.d_model, dropout_prob) + self.forward_amino_acid_embedding = nn.Embedding(self.known_token_count, self.d_model) + self.forward_encoder_layer = nn.TransformerEncoderLayer( + d_model=self.d_model, + nhead=nhead, + dim_feedforward=dim_feedforward, + batch_first=True, + ) + self.forward_encoder = nn.TransformerEncoder(self.forward_encoder_layer, layer_count) + + # Reverse direction components + self.reverse_pos_encoder = PositionalEncoding(self.d_model, dropout_prob) + self.reverse_amino_acid_embedding = nn.Embedding(self.known_token_count, self.d_model) + self.reverse_encoder_layer = nn.TransformerEncoderLayer( + d_model=self.d_model, + nhead=nhead, + dim_feedforward=dim_feedforward, + batch_first=True, + ) + self.reverse_encoder = nn.TransformerEncoder(self.reverse_encoder_layer, layer_count) + + # Output layers + self.combine_features = nn.Linear(2 * self.d_model, self.d_model) + self.output = nn.Linear(self.d_model, output_dim) + + self.init_weights() + + def init_weights(self) -> None: + initrange = 0.1 + self.combine_features.bias.data.zero_() + self.combine_features.weight.data.uniform_(-initrange, initrange) + self.output.bias.data.zero_() + self.output.weight.data.uniform_(-initrange, initrange) + + def represent_sequence(self, indices: Tensor, mask: Tensor, + embedding: nn.Embedding, pos_encoder: PositionalEncoding, + encoder: nn.TransformerEncoder) -> Tensor: + """Process sequence through one direction of the model.""" + embedded = embedding(indices) * math.sqrt(self.d_model) + embedded = pos_encoder(embedded.permute(1, 0, 2)).permute(1, 0, 2) + return encoder(embedded, src_key_padding_mask=~mask) + + def forward(self, amino_acid_indices: Tensor, mask: Tensor) -> Tensor: + batch_size, seq_len = amino_acid_indices.shape + seq_lengths = mask.sum(dim=1) + + # Forward direction - normal processing + forward_repr = self.represent_sequence( + amino_acid_indices, mask, + self.forward_amino_acid_embedding, + self.forward_pos_encoder, + self.forward_encoder + ) + + # Reverse direction - flip sequences and masks + reversed_indices = torch.zeros_like(amino_acid_indices) + reversed_mask = torch.zeros_like(mask) + + for i in range(batch_size): + length = seq_lengths[i] + # Reverse and left-pad the sequence + reversed_indices[i, -length:] = amino_acid_indices[i, :length].flip(0) + reversed_mask[i, -length:] = mask[i, :length].flip(0) + + reverse_repr = self.represent_sequence( + 
reversed_indices, reversed_mask, + self.reverse_amino_acid_embedding, + self.reverse_pos_encoder, + self.reverse_encoder + ) + + # Un-reverse the representations to align with forward direction + aligned_reverse_repr = torch.zeros_like(reverse_repr) + for i in range(batch_size): + length = seq_lengths[i] + aligned_reverse_repr[i, :length] = reverse_repr[i, -length:].flip(0) + + # Combine features + combined = torch.cat([forward_repr, aligned_reverse_repr], dim=-1) + combined = self.combine_features(combined) + + # Output layer + return self.output(combined).squeeze(-1) + + @property + def hyperparameters(self): + return { + "nhead": self.nhead, + "d_model_per_head": self.d_model_per_head, + "dim_feedforward": self.dim_feedforward, + "layer_count": self.forward_encoder.num_layers, + "dropout_prob": self.forward_pos_encoder.dropout.p, + "output_dim": self.output.out_features, + "known_token_count": self.known_token_count, + } + class SingleValueBinarySelectionModel(AbstractBinarySelectionModel): """A one parameter selection model as a baseline.""" From 295ed37d9c6b126225721e2dce6794736ee32e55 Mon Sep 17 00:00:00 2001 From: Will Dumm Date: Fri, 14 Feb 2025 12:04:27 -0800 Subject: [PATCH 02/10] format --- netam/models.py | 58 +++++++++++++++++++++++++++++++------------------ 1 file changed, 37 insertions(+), 21 deletions(-) diff --git a/netam/models.py b/netam/models.py index f33002e6..1fe46883 100644 --- a/netam/models.py +++ b/netam/models.py @@ -814,30 +814,38 @@ def __init__( self.dim_feedforward = dim_feedforward # Forward direction components self.forward_pos_encoder = PositionalEncoding(self.d_model, dropout_prob) - self.forward_amino_acid_embedding = nn.Embedding(self.known_token_count, self.d_model) + self.forward_amino_acid_embedding = nn.Embedding( + self.known_token_count, self.d_model + ) self.forward_encoder_layer = nn.TransformerEncoderLayer( d_model=self.d_model, nhead=nhead, dim_feedforward=dim_feedforward, batch_first=True, ) - self.forward_encoder = nn.TransformerEncoder(self.forward_encoder_layer, layer_count) - + self.forward_encoder = nn.TransformerEncoder( + self.forward_encoder_layer, layer_count + ) + # Reverse direction components self.reverse_pos_encoder = PositionalEncoding(self.d_model, dropout_prob) - self.reverse_amino_acid_embedding = nn.Embedding(self.known_token_count, self.d_model) + self.reverse_amino_acid_embedding = nn.Embedding( + self.known_token_count, self.d_model + ) self.reverse_encoder_layer = nn.TransformerEncoderLayer( d_model=self.d_model, nhead=nhead, dim_feedforward=dim_feedforward, batch_first=True, ) - self.reverse_encoder = nn.TransformerEncoder(self.reverse_encoder_layer, layer_count) - + self.reverse_encoder = nn.TransformerEncoder( + self.reverse_encoder_layer, layer_count + ) + # Output layers self.combine_features = nn.Linear(2 * self.d_model, self.d_model) self.output = nn.Linear(self.d_model, output_dim) - + self.init_weights() def init_weights(self) -> None: @@ -847,9 +855,14 @@ def init_weights(self) -> None: self.output.bias.data.zero_() self.output.weight.data.uniform_(-initrange, initrange) - def represent_sequence(self, indices: Tensor, mask: Tensor, - embedding: nn.Embedding, pos_encoder: PositionalEncoding, - encoder: nn.TransformerEncoder) -> Tensor: + def represent_sequence( + self, + indices: Tensor, + mask: Tensor, + embedding: nn.Embedding, + pos_encoder: PositionalEncoding, + encoder: nn.TransformerEncoder, + ) -> Tensor: """Process sequence through one direction of the model.""" embedded = embedding(indices) * 
math.sqrt(self.d_model) embedded = pos_encoder(embedded.permute(1, 0, 2)).permute(1, 0, 2) @@ -858,42 +871,44 @@ def represent_sequence(self, indices: Tensor, mask: Tensor, def forward(self, amino_acid_indices: Tensor, mask: Tensor) -> Tensor: batch_size, seq_len = amino_acid_indices.shape seq_lengths = mask.sum(dim=1) - + # Forward direction - normal processing forward_repr = self.represent_sequence( - amino_acid_indices, mask, + amino_acid_indices, + mask, self.forward_amino_acid_embedding, self.forward_pos_encoder, - self.forward_encoder + self.forward_encoder, ) - + # Reverse direction - flip sequences and masks reversed_indices = torch.zeros_like(amino_acid_indices) reversed_mask = torch.zeros_like(mask) - + for i in range(batch_size): length = seq_lengths[i] # Reverse and left-pad the sequence reversed_indices[i, -length:] = amino_acid_indices[i, :length].flip(0) reversed_mask[i, -length:] = mask[i, :length].flip(0) - + reverse_repr = self.represent_sequence( - reversed_indices, reversed_mask, + reversed_indices, + reversed_mask, self.reverse_amino_acid_embedding, self.reverse_pos_encoder, - self.reverse_encoder + self.reverse_encoder, ) - + # Un-reverse the representations to align with forward direction aligned_reverse_repr = torch.zeros_like(reverse_repr) for i in range(batch_size): length = seq_lengths[i] aligned_reverse_repr[i, :length] = reverse_repr[i, -length:].flip(0) - + # Combine features combined = torch.cat([forward_repr, aligned_reverse_repr], dim=-1) combined = self.combine_features(combined) - + # Output layer return self.output(combined).squeeze(-1) @@ -909,6 +924,7 @@ def hyperparameters(self): "known_token_count": self.known_token_count, } + class SingleValueBinarySelectionModel(AbstractBinarySelectionModel): """A one parameter selection model as a baseline.""" From 3346249894bfef9cd4620958d6f405f0b16de8a5 Mon Sep 17 00:00:00 2001 From: Will Dumm Date: Fri, 14 Feb 2025 16:16:35 -0800 Subject: [PATCH 03/10] fix claude mistake, refactor, add symmetric model just for fun --- netam/models.py | 82 +++++++++++++++++++++++++++++++++++++------------ 1 file changed, 62 insertions(+), 20 deletions(-) diff --git a/netam/models.py b/netam/models.py index 1fe46883..a0b2e9ab 100644 --- a/netam/models.py +++ b/netam/models.py @@ -23,6 +23,7 @@ aa_idx_tensor_of_str_ambig, PositionalEncoding, split_heavy_light_model_outputs, + AA_AMBIG_IDX, ) from typing import Tuple @@ -795,6 +796,49 @@ def predict(self, representation: Tensor): return wiggle(super().predict(representation), beta) +# TODO it's bad practice to hard code the AA_AMBIG_IDX as the padding +# value here. +def reverse_padded_seqs_and_mask(amino_acid_indices, mask, seq_lengths): + batch_size, _seq_len = amino_acid_indices.shape + reversed_indices = torch.full_like(amino_acid_indices, AA_AMBIG_IDX) + reversed_mask = torch.zeros_like(mask) + + for i in range(batch_size): + length = seq_lengths[i] + # Reverse and left-pad the sequence + reversed_indices[i, :length] = amino_acid_indices[i, :length].flip(0) + reversed_mask[i, :length] = mask[i, :length].flip(0) + return reversed_indices, reversed_mask + +def reverse_padded_output(reverse_repr, seq_lengths): + # Un-reverse the representations to align with forward direction + # TODO it may not matter, but I don't think the masked outputs are + # necessarily zero. 
+ batch_size, _seq_len, _d_model = reverse_repr.shape + aligned_reverse_repr = torch.zeros_like(reverse_repr) + for i in range(batch_size): + length = seq_lengths[i] + aligned_reverse_repr[i, :length] = reverse_repr[i, :length].flip(0) + return aligned_reverse_repr + + +class SymmetricTransformerBinarySelectionModelLinAct(TransformerBinarySelectionModelLinAct): + def represent(self, amino_acid_indices: Tensor, mask: Tensor) -> Tensor: + seq_lengths = mask.sum(dim=1) + reversed_indices, reversed_mask = reverse_padded_seqs_and_mask( + amino_acid_indices, mask, seq_lengths + ) + reversed_outputs = super().represent(reversed_indices, reversed_mask) + aligned_reverse_outputs = reverse_padded_output(reversed_outputs, seq_lengths) + outputs = super().represent(amino_acid_indices, mask) + return (outputs + aligned_reverse_outputs) / 2 + +class SymmetricTransformerBinarySelectionModelWiggleAct(SymmetricTransformerBinarySelectionModelLinAct): + """Here the beta parameter is fixed at 0.3.""" + + def predict(self, representation: Tensor): + return wiggle(super().predict(representation), 0.3) + class BidirectionalTransformerBinarySelectionModel(AbstractBinarySelectionModel): def __init__( self, @@ -855,7 +899,7 @@ def init_weights(self) -> None: self.output.bias.data.zero_() self.output.weight.data.uniform_(-initrange, initrange) - def represent_sequence( + def single_direction_represent_sequence( self, indices: Tensor, mask: Tensor, @@ -868,12 +912,14 @@ def represent_sequence( embedded = pos_encoder(embedded.permute(1, 0, 2)).permute(1, 0, 2) return encoder(embedded, src_key_padding_mask=~mask) - def forward(self, amino_acid_indices: Tensor, mask: Tensor) -> Tensor: - batch_size, seq_len = amino_acid_indices.shape + def represent(self, amino_acid_indices: Tensor, mask: Tensor) -> Tensor: + batch_size, _seq_len = amino_acid_indices.shape + # This is okay, as long as there are no masked ambiguities in the + # interior of the sequence... Otherwise it should work for paired seqs. 
seq_lengths = mask.sum(dim=1) # Forward direction - normal processing - forward_repr = self.represent_sequence( + forward_repr = self.single_direction_represent_sequence( amino_acid_indices, mask, self.forward_amino_acid_embedding, @@ -882,16 +928,11 @@ def forward(self, amino_acid_indices: Tensor, mask: Tensor) -> Tensor: ) # Reverse direction - flip sequences and masks - reversed_indices = torch.zeros_like(amino_acid_indices) - reversed_mask = torch.zeros_like(mask) - - for i in range(batch_size): - length = seq_lengths[i] - # Reverse and left-pad the sequence - reversed_indices[i, -length:] = amino_acid_indices[i, :length].flip(0) - reversed_mask[i, -length:] = mask[i, :length].flip(0) + reverse_indices, reversed_mask = reverse_padded_seqs_and_mask( + amino_acid_indices, mask, seq_lengths + ) - reverse_repr = self.represent_sequence( + reverse_repr = self.single_direction_represent_sequence( reversed_indices, reversed_mask, self.reverse_amino_acid_embedding, @@ -899,18 +940,19 @@ def forward(self, amino_acid_indices: Tensor, mask: Tensor) -> Tensor: self.reverse_encoder, ) - # Un-reverse the representations to align with forward direction - aligned_reverse_repr = torch.zeros_like(reverse_repr) - for i in range(batch_size): - length = seq_lengths[i] - aligned_reverse_repr[i, :length] = reverse_repr[i, -length:].flip(0) + aligned_reverse_repr = reverse_padded_output(reverse_repr, seq_lengths) # Combine features combined = torch.cat([forward_repr, aligned_reverse_repr], dim=-1) - combined = self.combine_features(combined) + return self.combine_features(combined) + def predict(self, representation: Tensor) -> Tensor: # Output layer - return self.output(combined).squeeze(-1) + return self.output(representation).squeeze(-1) + + def forward(self, amino_acid_indices: Tensor, mask: Tensor) -> Tensor: + return self.predict(self.represent(amino_acid_indices, mask)) + @property def hyperparameters(self): From eca9194066cb643fb003f4773d57938b4637efab Mon Sep 17 00:00:00 2001 From: Will Dumm Date: Fri, 14 Feb 2025 16:16:45 -0800 Subject: [PATCH 04/10] format --- netam/models.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/netam/models.py b/netam/models.py index a0b2e9ab..05a66443 100644 --- a/netam/models.py +++ b/netam/models.py @@ -810,6 +810,7 @@ def reverse_padded_seqs_and_mask(amino_acid_indices, mask, seq_lengths): reversed_mask[i, :length] = mask[i, :length].flip(0) return reversed_indices, reversed_mask + def reverse_padded_output(reverse_repr, seq_lengths): # Un-reverse the representations to align with forward direction # TODO it may not matter, but I don't think the masked outputs are @@ -822,7 +823,9 @@ def reverse_padded_output(reverse_repr, seq_lengths): return aligned_reverse_repr -class SymmetricTransformerBinarySelectionModelLinAct(TransformerBinarySelectionModelLinAct): +class SymmetricTransformerBinarySelectionModelLinAct( + TransformerBinarySelectionModelLinAct +): def represent(self, amino_acid_indices: Tensor, mask: Tensor) -> Tensor: seq_lengths = mask.sum(dim=1) reversed_indices, reversed_mask = reverse_padded_seqs_and_mask( @@ -833,12 +836,16 @@ def represent(self, amino_acid_indices: Tensor, mask: Tensor) -> Tensor: outputs = super().represent(amino_acid_indices, mask) return (outputs + aligned_reverse_outputs) / 2 -class SymmetricTransformerBinarySelectionModelWiggleAct(SymmetricTransformerBinarySelectionModelLinAct): + +class SymmetricTransformerBinarySelectionModelWiggleAct( + SymmetricTransformerBinarySelectionModelLinAct +): """Here the 
beta parameter is fixed at 0.3.""" def predict(self, representation: Tensor): return wiggle(super().predict(representation), 0.3) + class BidirectionalTransformerBinarySelectionModel(AbstractBinarySelectionModel): def __init__( self, @@ -953,7 +960,6 @@ def predict(self, representation: Tensor) -> Tensor: def forward(self, amino_acid_indices: Tensor, mask: Tensor) -> Tensor: return self.predict(self.represent(amino_acid_indices, mask)) - @property def hyperparameters(self): return { From e89a07d381602287c9dec6ab60c7c5e5870ab25c Mon Sep 17 00:00:00 2001 From: Will Dumm Date: Fri, 14 Feb 2025 20:51:23 -0800 Subject: [PATCH 05/10] docstrings --- netam/models.py | 20 +++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/netam/models.py b/netam/models.py index 05a66443..1bb152da 100644 --- a/netam/models.py +++ b/netam/models.py @@ -799,6 +799,18 @@ def predict(self, representation: Tensor): # TODO it's bad practice to hard code the AA_AMBIG_IDX as the padding # value here. def reverse_padded_seqs_and_mask(amino_acid_indices, mask, seq_lengths): + """ + Reverse the provided left-aligned amino acid sequences and masks, + but move the padding to the right of the reversed sequence. + Equivalent to right-aligning the sequences then reversing them. + + Args: + amino_acid_indices: (B, L) tensor of amino acid indices + mask: (B, L) tensor of masks + seq_lengths: (B,) tensor of sequence lengths + Returns: + A tuple of the reversed amino acid indices and mask. + """ batch_size, _seq_len = amino_acid_indices.shape reversed_indices = torch.full_like(amino_acid_indices, AA_AMBIG_IDX) reversed_mask = torch.zeros_like(mask) @@ -811,10 +823,11 @@ def reverse_padded_seqs_and_mask(amino_acid_indices, mask, seq_lengths): return reversed_indices, reversed_mask +# TODO it may not matter, but I don't think the masked outputs are +# necessarily zero. def reverse_padded_output(reverse_repr, seq_lengths): - # Un-reverse the representations to align with forward direction - # TODO it may not matter, but I don't think the masked outputs are - # necessarily zero. + """Companion to `reverse_padded_seqs_and_mask` that reverses a model's representation + so that it aligns with the forward direction of that function's input sequence.""" batch_size, _seq_len, _d_model = reverse_repr.shape aligned_reverse_repr = torch.zeros_like(reverse_repr) for i in range(batch_size): @@ -947,6 +960,7 @@ def represent(self, amino_acid_indices: Tensor, mask: Tensor) -> Tensor: self.reverse_encoder, ) + # un-reverse to align with forward representation aligned_reverse_repr = reverse_padded_output(reverse_repr, seq_lengths) # Combine features From 1e11e7b58db2e1933559ef8e41245b48e92e869f Mon Sep 17 00:00:00 2001 From: Will Dumm Date: Fri, 14 Feb 2025 20:51:34 -0800 Subject: [PATCH 06/10] format --- netam/models.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/netam/models.py b/netam/models.py index 1bb152da..36d1193c 100644 --- a/netam/models.py +++ b/netam/models.py @@ -799,10 +799,9 @@ def predict(self, representation: Tensor): # TODO it's bad practice to hard code the AA_AMBIG_IDX as the padding # value here. def reverse_padded_seqs_and_mask(amino_acid_indices, mask, seq_lengths): - """ - Reverse the provided left-aligned amino acid sequences and masks, - but move the padding to the right of the reversed sequence. - Equivalent to right-aligning the sequences then reversing them. 
+ """Reverse the provided left-aligned amino acid sequences and masks, but move the + padding to the right of the reversed sequence. Equivalent to right-aligning the + sequences then reversing them. Args: amino_acid_indices: (B, L) tensor of amino acid indices @@ -826,8 +825,9 @@ def reverse_padded_seqs_and_mask(amino_acid_indices, mask, seq_lengths): # TODO it may not matter, but I don't think the masked outputs are # necessarily zero. def reverse_padded_output(reverse_repr, seq_lengths): - """Companion to `reverse_padded_seqs_and_mask` that reverses a model's representation - so that it aligns with the forward direction of that function's input sequence.""" + """Companion to `reverse_padded_seqs_and_mask` that reverses a model's + representation so that it aligns with the forward direction of that function's input + sequence.""" batch_size, _seq_len, _d_model = reverse_repr.shape aligned_reverse_repr = torch.zeros_like(reverse_repr) for i in range(batch_size): From 7f209eecae6f337a53fd8dc2cc7be39e1a8194f1 Mon Sep 17 00:00:00 2001 From: Will Dumm Date: Thu, 20 Feb 2025 11:04:01 -0800 Subject: [PATCH 07/10] Improve sequence reversal --- netam/models.py | 74 ++++++++++++++++++++------------------------- tests/test_netam.py | 36 +++++++++++++++++++++- 2 files changed, 68 insertions(+), 42 deletions(-) diff --git a/netam/models.py b/netam/models.py index 36d1193c..d1fad1e4 100644 --- a/netam/models.py +++ b/netam/models.py @@ -798,54 +798,46 @@ def predict(self, representation: Tensor): # TODO it's bad practice to hard code the AA_AMBIG_IDX as the padding # value here. -def reverse_padded_seqs_and_mask(amino_acid_indices, mask, seq_lengths): - """Reverse the provided left-aligned amino acid sequences and masks, but move the - padding to the right of the reversed sequence. Equivalent to right-aligning the - sequences then reversing them. +def reverse_padded_tensors(padded_tensors, padding_mask, padding_value, reversed_dim=1): + """Reverse the valid values in provided padded_tensors along the specified + dimension, keeping padding in the same place. For example, if the input is left- + aligned amino acid sequences and masks, move the padding to the right of the + reversed sequence. Equivalent to right-aligning the sequences then reversing them. A + sequence `123456XXXXX` becomes `654321XXXXX`. + + The original padding mask remains valid for the returned tensor. Args: - amino_acid_indices: (B, L) tensor of amino acid indices - mask: (B, L) tensor of masks - seq_lengths: (B,) tensor of sequence lengths + padded_tensors: (B, L) tensor of amino acid indices + padding_mask: (B, L) tensor of masks, with True indicating valid values, and False indicating padding values. + padding_value: The value to fill returned tensor where padding_mask is False. + reversed_dim: The dimension along which to reverse the tensor. When input is a batch of sequences to be reversed, the default value of 1 is the correct choice. Returns: - A tuple of the reversed amino acid indices and mask. + The reversed tensor, with the same shape as padded_tensors, and with padding still specified by padding_mask. 
""" - batch_size, _seq_len = amino_acid_indices.shape - reversed_indices = torch.full_like(amino_acid_indices, AA_AMBIG_IDX) - reversed_mask = torch.zeros_like(mask) - - for i in range(batch_size): - length = seq_lengths[i] - # Reverse and left-pad the sequence - reversed_indices[i, :length] = amino_acid_indices[i, :length].flip(0) - reversed_mask[i, :length] = mask[i, :length].flip(0) - return reversed_indices, reversed_mask - - -# TODO it may not matter, but I don't think the masked outputs are -# necessarily zero. -def reverse_padded_output(reverse_repr, seq_lengths): - """Companion to `reverse_padded_seqs_and_mask` that reverses a model's - representation so that it aligns with the forward direction of that function's input - sequence.""" - batch_size, _seq_len, _d_model = reverse_repr.shape - aligned_reverse_repr = torch.zeros_like(reverse_repr) - for i in range(batch_size): - length = seq_lengths[i] - aligned_reverse_repr[i, :length] = reverse_repr[i, :length].flip(0) - return aligned_reverse_repr + reversed_indices = torch.full_like(padded_tensors, padding_value) + reversed_indices[padding_mask] = padded_tensors.flip(reversed_dim)[ + padding_mask.flip(reversed_dim) + ] + return reversed_indices class SymmetricTransformerBinarySelectionModelLinAct( TransformerBinarySelectionModelLinAct ): def represent(self, amino_acid_indices: Tensor, mask: Tensor) -> Tensor: - seq_lengths = mask.sum(dim=1) - reversed_indices, reversed_mask = reverse_padded_seqs_and_mask( - amino_acid_indices, mask, seq_lengths + reversed_indices = reverse_padded_tensors( + amino_acid_indices, mask, AA_AMBIG_IDX ) - reversed_outputs = super().represent(reversed_indices, reversed_mask) - aligned_reverse_outputs = reverse_padded_output(reversed_outputs, seq_lengths) + # This assumes that the mask is True on all sites in the interior of + # the sequence, and False on padding. This assumption is not met for + # sequences with masked ambiguities in the interior. + # We convert a padded sequence `123456XXXXX` to `654321XXXXX`, so the + # mask does not need to be reversed. + reversed_outputs = super().represent(reversed_indices, mask) + # TODO it may not matter, but I don't think the masked outputs are + # necessarily zero. 
+ aligned_reverse_outputs = reverse_padded_tensors(reversed_outputs, mask, 0.0) outputs = super().represent(amino_acid_indices, mask) return (outputs + aligned_reverse_outputs) / 2 @@ -948,20 +940,20 @@ def represent(self, amino_acid_indices: Tensor, mask: Tensor) -> Tensor: ) # Reverse direction - flip sequences and masks - reverse_indices, reversed_mask = reverse_padded_seqs_and_mask( - amino_acid_indices, mask, seq_lengths + reversed_indices = reverse_padded_tensors( + amino_acid_indices, mask, AA_AMBIG_IDX ) reverse_repr = self.single_direction_represent_sequence( reversed_indices, - reversed_mask, + mask, self.reverse_amino_acid_embedding, self.reverse_pos_encoder, self.reverse_encoder, ) # un-reverse to align with forward representation - aligned_reverse_repr = reverse_padded_output(reverse_repr, seq_lengths) + aligned_reverse_repr = reverse_padded_tensors(reverse_repr, mask, 0.0) # Combine features combined = torch.cat([forward_repr, aligned_reverse_repr], dim=-1) diff --git a/tests/test_netam.py b/tests/test_netam.py index c2fe9088..c0faed11 100644 --- a/tests/test_netam.py +++ b/tests/test_netam.py @@ -8,7 +8,12 @@ import netam.framework as framework from netam.common import BIG from netam.framework import SHMoofDataset, SHMBurrito, RSSHMBurrito -from netam.models import SHMoofModel, RSSHMoofModel, IndepRSCNNModel +from netam.models import ( + SHMoofModel, + RSSHMoofModel, + IndepRSCNNModel, + reverse_padded_tensors, +) @pytest.fixture @@ -114,3 +119,32 @@ def test_standardize_model_rates(mini_rsburrito): mini_rsburrito.standardize_model_rates() vrc01_rate_14 = mini_rsburrito.vrc01_site_14_model_rate() assert np.isclose(vrc01_rate_14, 1.0) + + +def test_reverse_padded_tensors(): + # Here we just test that we can apply the function twice and get the + # original input back. + test_tensor = torch.tensor( + [ + [1, 2, 3, 4, 0, 0], + [1, 2, 0, 0, 0, 0], + [1, 0, 0, 0, 0, 0], + [1, 2, 3, 4, 5, 0], + [1, 2, 3, 4, 5, 6], + [1, 2, 0, 0, 0, 0], + ] + ) + true_reversed = torch.tensor( + [ + [4, 3, 2, 1, 0, 0], + [2, 1, 0, 0, 0, 0], + [1, 0, 0, 0, 0, 0], + [5, 4, 3, 2, 1, 0], + [6, 5, 4, 3, 2, 1], + [2, 1, 0, 0, 0, 0], + ] + ) + mask = test_tensor > 0 + reversed_tensor = reverse_padded_tensors(test_tensor, mask, 0) + assert torch.equal(true_reversed, reversed_tensor) + assert torch.equal(test_tensor, reverse_padded_tensors(reversed_tensor, mask, 0)) From 2054949f1c3babe823d9894d264fea5338740ae3 Mon Sep 17 00:00:00 2001 From: Will Dumm Date: Mon, 24 Feb 2025 14:42:16 -0800 Subject: [PATCH 08/10] drop sym and add wiggle to bidirectional model --- netam/dxsm.py | 4 ++-- netam/models.py | 42 +++++++++--------------------------------- netam/sequences.py | 3 +++ 3 files changed, 14 insertions(+), 35 deletions(-) diff --git a/netam/dxsm.py b/netam/dxsm.py index 4ac7a7a4..ebfbf7c0 100644 --- a/netam/dxsm.py +++ b/netam/dxsm.py @@ -32,7 +32,7 @@ token_mask_of_aa_idxs, MAX_AA_TOKEN_IDX, RESERVED_TOKEN_REGEX, - AA_AMBIG_IDX, + AA_PADDING_TOKEN, ) @@ -134,7 +134,7 @@ def of_seriess( # We have sequences of varying length, so we start with all tensors set # to the ambiguous amino acid, and then will fill in the actual values # below. 
- aa_parents_idxss = torch.full((pcp_count, max_aa_seq_len), AA_AMBIG_IDX) + aa_parents_idxss = torch.full((pcp_count, max_aa_seq_len), AA_PADDING_TOKEN) aa_children_idxss = aa_parents_idxss.clone() aa_subs_indicators = torch.zeros((pcp_count, max_aa_seq_len)) diff --git a/netam/models.py b/netam/models.py index d1fad1e4..04b2cb43 100644 --- a/netam/models.py +++ b/netam/models.py @@ -23,7 +23,7 @@ aa_idx_tensor_of_str_ambig, PositionalEncoding, split_heavy_light_model_outputs, - AA_AMBIG_IDX, + AA_PADDING_TOKEN, ) from typing import Tuple @@ -796,8 +796,6 @@ def predict(self, representation: Tensor): return wiggle(super().predict(representation), beta) -# TODO it's bad practice to hard code the AA_AMBIG_IDX as the padding -# value here. def reverse_padded_tensors(padded_tensors, padding_mask, padding_value, reversed_dim=1): """Reverse the valid values in provided padded_tensors along the specified dimension, keeping padding in the same place. For example, if the input is left- @@ -822,35 +820,6 @@ def reverse_padded_tensors(padded_tensors, padding_mask, padding_value, reversed return reversed_indices -class SymmetricTransformerBinarySelectionModelLinAct( - TransformerBinarySelectionModelLinAct -): - def represent(self, amino_acid_indices: Tensor, mask: Tensor) -> Tensor: - reversed_indices = reverse_padded_tensors( - amino_acid_indices, mask, AA_AMBIG_IDX - ) - # This assumes that the mask is True on all sites in the interior of - # the sequence, and False on padding. This assumption is not met for - # sequences with masked ambiguities in the interior. - # We convert a padded sequence `123456XXXXX` to `654321XXXXX`, so the - # mask does not need to be reversed. - reversed_outputs = super().represent(reversed_indices, mask) - # TODO it may not matter, but I don't think the masked outputs are - # necessarily zero. - aligned_reverse_outputs = reverse_padded_tensors(reversed_outputs, mask, 0.0) - outputs = super().represent(amino_acid_indices, mask) - return (outputs + aligned_reverse_outputs) / 2 - - -class SymmetricTransformerBinarySelectionModelWiggleAct( - SymmetricTransformerBinarySelectionModelLinAct -): - """Here the beta parameter is fixed at 0.3.""" - - def predict(self, representation: Tensor): - return wiggle(super().predict(representation), 0.3) - - class BidirectionalTransformerBinarySelectionModel(AbstractBinarySelectionModel): def __init__( self, @@ -941,7 +910,7 @@ def represent(self, amino_acid_indices: Tensor, mask: Tensor) -> Tensor: # Reverse direction - flip sequences and masks reversed_indices = reverse_padded_tensors( - amino_acid_indices, mask, AA_AMBIG_IDX + amino_acid_indices, mask, AA_PADDING_TOKEN ) reverse_repr = self.single_direction_represent_sequence( @@ -979,6 +948,13 @@ def hyperparameters(self): } +class BidirectionalTransformerBinarySelectionModelWiggleAct(BidirectionalTransformerBinarySelectionModel): + """Here the beta parameter is fixed at 0.3.""" + + def predict(self, representation: Tensor): + return wiggle(super().predict(representation), 0.3) + + class SingleValueBinarySelectionModel(AbstractBinarySelectionModel): """A one parameter selection model as a baseline.""" diff --git a/netam/sequences.py b/netam/sequences.py index 01d99179..9263d68b 100644 --- a/netam/sequences.py +++ b/netam/sequences.py @@ -22,6 +22,9 @@ NT_STR_SORTED = "".join(BASES) BASES_AND_N_TO_INDEX = {base: idx for idx, base in enumerate(NT_STR_SORTED + "N")} AA_AMBIG_IDX = len(AA_STR_SORTED) +# Used for padding amino acid sequences to the same length. 
Differentiated by +# name in case we add a padding token other than AA_AMBIG_IDX later. +AA_PADDING_TOKEN = AA_AMBIG_IDX CODONS = ["".join(codon_list) for codon_list in itertools.product(BASES, repeat=3)] STOP_CODONS = ["TAA", "TAG", "TGA"] From 2ec1ef09fb969e0a92fdbda1feb3acb9d3a7daea Mon Sep 17 00:00:00 2001 From: Will Dumm Date: Mon, 24 Feb 2025 14:42:30 -0800 Subject: [PATCH 09/10] format --- netam/models.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/netam/models.py b/netam/models.py index 04b2cb43..d5de4fc4 100644 --- a/netam/models.py +++ b/netam/models.py @@ -948,7 +948,9 @@ def hyperparameters(self): } -class BidirectionalTransformerBinarySelectionModelWiggleAct(BidirectionalTransformerBinarySelectionModel): +class BidirectionalTransformerBinarySelectionModelWiggleAct( + BidirectionalTransformerBinarySelectionModel +): """Here the beta parameter is fixed at 0.3.""" def predict(self, representation: Tensor): From 864b24755f312df0e78d7bcf0e64b38de3afca8f Mon Sep 17 00:00:00 2001 From: Will Dumm Date: Mon, 24 Feb 2025 14:43:56 -0800 Subject: [PATCH 10/10] lint --- netam/models.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/netam/models.py b/netam/models.py index d5de4fc4..ee43e019 100644 --- a/netam/models.py +++ b/netam/models.py @@ -894,10 +894,8 @@ def single_direction_represent_sequence( return encoder(embedded, src_key_padding_mask=~mask) def represent(self, amino_acid_indices: Tensor, mask: Tensor) -> Tensor: - batch_size, _seq_len = amino_acid_indices.shape # This is okay, as long as there are no masked ambiguities in the - # interior of the sequence... Otherwise it should work for paired seqs. - seq_lengths = mask.sum(dim=1) + # interior of the sequence... Otherwise it should also work for paired seqs. # Forward direction - normal processing forward_repr = self.single_direction_represent_sequence(
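
A usage sketch of the reversal helper these patches introduce, mirroring test_reverse_padded_tensors from PATCH 07. This is illustrative only and not part of any patch above; it assumes a netam checkout with these patches applied so that reverse_padded_tensors is importable, and the toy tensors are made up for the example.

    import torch
    from netam.models import reverse_padded_tensors  # helper added in PATCH 07

    seqs = torch.tensor([[1, 2, 3, 4, 0, 0],
                         [1, 2, 0, 0, 0, 0]])
    mask = seqs > 0  # True on real residues, False on padding

    # `123400` -> `432100`: valid entries are reversed in place, padding stays
    # put, and the original mask still describes the reversed tensor.
    rev = reverse_padded_tensors(seqs, mask, 0)
    assert torch.equal(rev, torch.tensor([[4, 3, 2, 1, 0, 0],
                                          [2, 1, 0, 0, 0, 0]]))

    # The operation is an involution: applying it again recovers the input.
    # This is why the same helper both builds the reverse-direction encoder's
    # input and re-aligns that encoder's output with the forward representation
    # before the two are concatenated and passed through combine_features.
    assert torch.equal(reverse_padded_tensors(rev, mask, 0), seqs)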