Skip to content

Commit

Permalink
adding the kitchen sink to dcsm
Browse files Browse the repository at this point in the history
  • Loading branch information
matsen committed Jan 14, 2025
1 parent 0144b87 commit fddcee1
Showing 1 changed file with 21 additions and 2 deletions.
23 changes: 21 additions & 2 deletions netam/dcsm.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from netam.common import aa_idx_tensor_of_str_ambig
from netam.sequences import (
aa_idx_array_of_str,
aa_subs_indicator_tensor_of,
nt_idx_tensor_of_str,
token_mask_of_aa_idxs,
translate_sequence,
Expand Down Expand Up @@ -113,6 +114,8 @@ def __init__(
# Important to use the unmodified versions of nt_parents and
# nt_children so they still contain special tokens.
aa_parents = translate_sequences(nt_parents)
aa_children = translate_sequences(nt_children)

self.max_codon_seq_len = max(len(seq) for seq in aa_parents)
# We have sequences of varying length, so we start with all tensors set
# to the ambiguous amino acid, and then will fill in the actual values
Expand All @@ -125,13 +128,18 @@ def __init__(
self.aa_parents_idxss = torch.full(
(pcp_count, self.max_codon_seq_len), MAX_AA_TOKEN_IDX
)
self.aa_children_idxss = torch.full(
(pcp_count, self.max_codon_seq_len), MAX_AA_TOKEN_IDX
)
# TODO here we are computing the subs indicators. This is handy for OE plots.
self.aa_subs_indicators = torch.zeros((pcp_count, self.max_codon_seq_len))

self.masks = torch.ones((pcp_count, self.max_codon_seq_len), dtype=torch.bool)

# We are using the modified nt_parents and nt_children here because we
# don't want any funky symbols in our codon indices.
for i, (nt_parent, nt_child, aa_parent) in enumerate(
zip(self.nt_parents, self.nt_children, aa_parents)
for i, (nt_parent, nt_child, aa_parent, aa_child) in enumerate(
zip(self.nt_parents, self.nt_children, aa_parents, aa_children)
):
self.masks[i, :] = codon_mask_tensor_of(
nt_parent, nt_child, aa_length=self.max_codon_seq_len
Expand All @@ -150,9 +158,16 @@ def __init__(
self.aa_parents_idxss[i, :codon_seq_len] = aa_idx_tensor_of_str_ambig(
aa_parent
)
self.aa_children_idxss[i, :codon_seq_len] = aa_idx_tensor_of_str_ambig(
aa_child
)
self.aa_subs_indicators[i, :codon_seq_len] = aa_subs_indicator_tensor_of(
aa_parent, aa_child
)

assert torch.all(self.masks.sum(dim=1) > 0)
assert torch.max(self.aa_parents_idxss) <= MAX_AA_TOKEN_IDX
assert torch.max(self.aa_children_idxss) <= MAX_AA_TOKEN_IDX
assert torch.max(self.codon_parents_idxss) <= AMBIGUOUS_CODON_IDX

self._branch_lengths = branch_lengths
Expand Down Expand Up @@ -232,6 +247,8 @@ def __getitem__(self, idx):
"codon_parents_idxs": self.codon_parents_idxss[idx],
"codon_children_idxs": self.codon_children_idxss[idx],
"aa_parents_idxs": self.aa_parents_idxss[idx],
"aa_children_idxs": self.aa_children_idxss[idx],
"subs_indicator": self.aa_subs_indicators[idx],
"mask": self.masks[idx],
"log_neutral_codon_probs": self.log_neutral_codon_probss[idx],
"nt_rates": self.nt_ratess[idx],
Expand All @@ -242,6 +259,8 @@ def to(self, device):
self.codon_parents_idxss = self.codon_parents_idxss.to(device)
self.codon_children_idxss = self.codon_children_idxss.to(device)
self.aa_parents_idxss = self.aa_parents_idxss.to(device)
self.aa_children_idxss = self.aa_children_idxss.to(device)
self.aa_subs_indicators = self.aa_subs_indicators.to(device)
self.masks = self.masks.to(device)
self.log_neutral_codon_probss = self.log_neutral_codon_probss.to(device)
self.nt_ratess = self.nt_ratess.to(device)
Expand Down

0 comments on commit fddcee1

Please sign in to comment.