Add bidirectional model #118

Merged: 10 commits, Feb 25, 2025
4 changes: 2 additions & 2 deletions netam/dxsm.py
@@ -32,7 +32,7 @@
token_mask_of_aa_idxs,
MAX_AA_TOKEN_IDX,
RESERVED_TOKEN_REGEX,
-AA_AMBIG_IDX,
+AA_PADDING_TOKEN,
)


@@ -134,7 +134,7 @@ def of_seriess(
# We have sequences of varying length, so we start with all tensors set
# to the ambiguous amino acid, and then will fill in the actual values
# below.
-aa_parents_idxss = torch.full((pcp_count, max_aa_seq_len), AA_AMBIG_IDX)
+aa_parents_idxss = torch.full((pcp_count, max_aa_seq_len), AA_PADDING_TOKEN)
aa_children_idxss = aa_parents_idxss.clone()
aa_subs_indicators = torch.zeros((pcp_count, max_aa_seq_len))

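As context for this change: the dataset code pre-fills the index tensors with the padding token and then overwrites the valid positions. A minimal sketch of that pattern in plain PyTorch (the variable names and toy sequences below are illustrative, not taken from netam):

```python
import torch

AA_PADDING_TOKEN = 20  # placeholder value; netam derives it from AA_STR_SORTED
seqs = [[3, 7, 1], [5, 2]]  # toy index sequences of unequal length
max_len = max(len(s) for s in seqs)

# Start with every position set to the padding token, then fill in real values.
idxss = torch.full((len(seqs), max_len), AA_PADDING_TOKEN)
for row, seq in enumerate(seqs):
    idxss[row, : len(seq)] = torch.tensor(seq)

print(idxss)  # tensor([[ 3,  7,  1], [ 5,  2, 20]])
```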
160 changes: 160 additions & 0 deletions netam/models.py
@@ -23,6 +23,7 @@
aa_idx_tensor_of_str_ambig,
PositionalEncoding,
split_heavy_light_model_outputs,
+AA_PADDING_TOKEN,
)

from typing import Tuple
@@ -795,6 +796,165 @@ def predict(self, representation: Tensor):
return wiggle(super().predict(representation), beta)


def reverse_padded_tensors(padded_tensors, padding_mask, padding_value, reversed_dim=1):
"""Reverse the valid values in provided padded_tensors along the specified
dimension, keeping padding in the same place. For example, if the input is left-
aligned amino acid sequences and masks, move the padding to the right of the
reversed sequence. Equivalent to right-aligning the sequences then reversing them. A
sequence `123456XXXXX` becomes `654321XXXXX`.

The original padding mask remains valid for the returned tensor.

Args:
padded_tensors: (B, L) or (B, L, D) tensor, e.g. amino acid indices or per-site representations.
padding_mask: (B, L) tensor of masks, with True indicating valid values, and False indicating padding values.
padding_value: The value to fill returned tensor where padding_mask is False.
reversed_dim: The dimension along which to reverse the tensor. When the input is a batch of sequences to be reversed, the default value of 1 is the correct choice.
Returns:
The reversed tensor, with the same shape as padded_tensors, and with padding still specified by padding_mask.
"""
reversed_indices = torch.full_like(padded_tensors, padding_value)
reversed_indices[padding_mask] = padded_tensors.flip(reversed_dim)[
padding_mask.flip(reversed_dim)
]
return reversed_indices


class BidirectionalTransformerBinarySelectionModel(AbstractBinarySelectionModel):
def __init__(
self,
nhead: int,
d_model_per_head: int,
dim_feedforward: int,
layer_count: int,
dropout_prob: float = 0.5,
output_dim: int = 1,
known_token_count: int = MAX_AA_TOKEN_IDX + 1,
):
super().__init__()
self.known_token_count = known_token_count
self.d_model_per_head = d_model_per_head
self.d_model = d_model_per_head * nhead
self.nhead = nhead
self.dim_feedforward = dim_feedforward
# Forward direction components
self.forward_pos_encoder = PositionalEncoding(self.d_model, dropout_prob)
self.forward_amino_acid_embedding = nn.Embedding(
self.known_token_count, self.d_model
)
self.forward_encoder_layer = nn.TransformerEncoderLayer(
d_model=self.d_model,
nhead=nhead,
dim_feedforward=dim_feedforward,
batch_first=True,
)
self.forward_encoder = nn.TransformerEncoder(
self.forward_encoder_layer, layer_count
)

# Reverse direction components
self.reverse_pos_encoder = PositionalEncoding(self.d_model, dropout_prob)
self.reverse_amino_acid_embedding = nn.Embedding(
self.known_token_count, self.d_model
)
self.reverse_encoder_layer = nn.TransformerEncoderLayer(
d_model=self.d_model,
nhead=nhead,
dim_feedforward=dim_feedforward,
batch_first=True,
)
self.reverse_encoder = nn.TransformerEncoder(
self.reverse_encoder_layer, layer_count
)

# Output layers
self.combine_features = nn.Linear(2 * self.d_model, self.d_model)
self.output = nn.Linear(self.d_model, output_dim)

self.init_weights()

def init_weights(self) -> None:
initrange = 0.1
self.combine_features.bias.data.zero_()
self.combine_features.weight.data.uniform_(-initrange, initrange)
self.output.bias.data.zero_()
self.output.weight.data.uniform_(-initrange, initrange)

def single_direction_represent_sequence(
self,
indices: Tensor,
mask: Tensor,
embedding: nn.Embedding,
pos_encoder: PositionalEncoding,
encoder: nn.TransformerEncoder,
) -> Tensor:
"""Process sequence through one direction of the model."""
embedded = embedding(indices) * math.sqrt(self.d_model)
embedded = pos_encoder(embedded.permute(1, 0, 2)).permute(1, 0, 2)
return encoder(embedded, src_key_padding_mask=~mask)

def represent(self, amino_acid_indices: Tensor, mask: Tensor) -> Tensor:
# This is okay as long as there are no masked ambiguities in the
# interior of the sequence. Aside from that caveat, it should also work for paired seqs.

# Forward direction - normal processing
forward_repr = self.single_direction_represent_sequence(
amino_acid_indices,
mask,
self.forward_amino_acid_embedding,
self.forward_pos_encoder,
self.forward_encoder,
)

# Reverse direction - flip sequences and masks
reversed_indices = reverse_padded_tensors(
amino_acid_indices, mask, AA_PADDING_TOKEN
)

reverse_repr = self.single_direction_represent_sequence(
reversed_indices,
mask,
self.reverse_amino_acid_embedding,
self.reverse_pos_encoder,
self.reverse_encoder,
)

# un-reverse to align with forward representation
aligned_reverse_repr = reverse_padded_tensors(reverse_repr, mask, 0.0)

# Combine features
combined = torch.cat([forward_repr, aligned_reverse_repr], dim=-1)
return self.combine_features(combined)

def predict(self, representation: Tensor) -> Tensor:
# Output layer
return self.output(representation).squeeze(-1)

def forward(self, amino_acid_indices: Tensor, mask: Tensor) -> Tensor:
return self.predict(self.represent(amino_acid_indices, mask))

@property
def hyperparameters(self):
return {
"nhead": self.nhead,
"d_model_per_head": self.d_model_per_head,
"dim_feedforward": self.dim_feedforward,
"layer_count": self.forward_encoder.num_layers,
"dropout_prob": self.forward_pos_encoder.dropout.p,
"output_dim": self.output.out_features,
"known_token_count": self.known_token_count,
}


class BidirectionalTransformerBinarySelectionModelWiggleAct(
BidirectionalTransformerBinarySelectionModel
):
"""Here the beta parameter is fixed at 0.3."""

def predict(self, representation: Tensor):
return wiggle(super().predict(representation), 0.3)


class SingleValueBinarySelectionModel(AbstractBinarySelectionModel):
"""A one parameter selection model as a baseline."""

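For readers skimming the diff, here is a minimal usage sketch of the new bidirectional model. It is not part of the PR: the hyperparameter values, batch shapes, and random toy input are illustrative only, and it assumes netam and its dependencies are installed and that the padding token index lies within the model's embedding vocabulary (as in the dataset code above).

```python
import torch
from netam.models import BidirectionalTransformerBinarySelectionModel
from netam.sequences import AA_STR_SORTED, AA_PADDING_TOKEN

model = BidirectionalTransformerBinarySelectionModel(
    nhead=4,
    d_model_per_head=16,
    dim_feedforward=256,
    layer_count=2,
)
model.eval()  # disable dropout for a deterministic forward pass

# Two toy sequences of lengths 5 and 3, left-aligned and padded to length 6.
lengths = [5, 3]
max_len = 6
indices = torch.full((len(lengths), max_len), AA_PADDING_TOKEN)
mask = torch.zeros((len(lengths), max_len), dtype=torch.bool)
for row, length in enumerate(lengths):
    indices[row, :length] = torch.randint(len(AA_STR_SORTED), (length,))
    mask[row, :length] = True

with torch.no_grad():
    out = model(indices, mask)
print(out.shape)  # torch.Size([2, 6]): one value per site, padded sites included
```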
3 changes: 3 additions & 0 deletions netam/sequences.py
@@ -22,6 +22,9 @@
NT_STR_SORTED = "".join(BASES)
BASES_AND_N_TO_INDEX = {base: idx for idx, base in enumerate(NT_STR_SORTED + "N")}
AA_AMBIG_IDX = len(AA_STR_SORTED)
+# Used for padding amino acid sequences to the same length. Differentiated by
+# name in case we add a padding token other than AA_AMBIG_IDX later.
+AA_PADDING_TOKEN = AA_AMBIG_IDX

CODONS = ["".join(codon_list) for codon_list in itertools.product(BASES, repeat=3)]
STOP_CODONS = ["TAA", "TAG", "TGA"]
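A small sanity-check sketch of what these constants mean (not part of the PR; it simply restates the definitions above and assumes netam is importable): the padding token currently reuses the ambiguous-amino-acid index, which sits immediately after the standard amino acid indices in AA_STR_SORTED.

```python
from netam.sequences import AA_STR_SORTED, AA_AMBIG_IDX, AA_PADDING_TOKEN

# The padding token is currently an alias for the ambiguous index, i.e. the
# first index past the standard amino acids.
assert AA_PADDING_TOKEN == AA_AMBIG_IDX == len(AA_STR_SORTED)
```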
36 changes: 35 additions & 1 deletion tests/test_netam.py
@@ -8,7 +8,12 @@
import netam.framework as framework
from netam.common import BIG
from netam.framework import SHMoofDataset, SHMBurrito, RSSHMBurrito
-from netam.models import SHMoofModel, RSSHMoofModel, IndepRSCNNModel
+from netam.models import (
+    SHMoofModel,
+    RSSHMoofModel,
+    IndepRSCNNModel,
+    reverse_padded_tensors,
+)


@pytest.fixture
@@ -114,3 +119,32 @@ def test_standardize_model_rates(mini_rsburrito):
mini_rsburrito.standardize_model_rates()
vrc01_rate_14 = mini_rsburrito.vrc01_site_14_model_rate()
assert np.isclose(vrc01_rate_14, 1.0)


def test_reverse_padded_tensors():
# Test against a hand-computed reversal, and check that applying the
# function twice recovers the original input.
test_tensor = torch.tensor(
[
[1, 2, 3, 4, 0, 0],
[1, 2, 0, 0, 0, 0],
[1, 0, 0, 0, 0, 0],
[1, 2, 3, 4, 5, 0],
[1, 2, 3, 4, 5, 6],
[1, 2, 0, 0, 0, 0],
]
)
true_reversed = torch.tensor(
[
[4, 3, 2, 1, 0, 0],
[2, 1, 0, 0, 0, 0],
[1, 0, 0, 0, 0, 0],
[5, 4, 3, 2, 1, 0],
[6, 5, 4, 3, 2, 1],
[2, 1, 0, 0, 0, 0],
]
)
mask = test_tensor > 0
reversed_tensor = reverse_padded_tensors(test_tensor, mask, 0)
assert torch.equal(true_reversed, reversed_tensor)
assert torch.equal(test_tensor, reverse_padded_tensors(reversed_tensor, mask, 0))
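To complement the round-trip test above, here is a hedged sketch (not part of the PR) checking the docstring's claim that reverse_padded_tensors is equivalent to right-aligning each sequence and then flipping it; the toy tensor is illustrative.

```python
import torch
from netam.models import reverse_padded_tensors

padded = torch.tensor([[1, 2, 3, 4, 0, 0],
                       [5, 6, 0, 0, 0, 0]])
mask = padded > 0

# Right-align each row's valid values, then flip the whole row.
right_aligned = torch.zeros_like(padded)
right_aligned[mask.flip(1)] = padded[mask]
manual = right_aligned.flip(1)

assert torch.equal(manual, reverse_padded_tensors(padded, mask, padding_value=0))
```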