Commit 284c2ee

address TODOs, format and lint

willdumm committed Jan 28, 2025
1 parent c3b4abc commit 284c2ee
Showing 5 changed files with 38 additions and 68 deletions.
16 changes: 0 additions & 16 deletions netam/dasm.py
@@ -177,22 +177,6 @@ def loss_of_batch(self, batch):
         csp_loss = self.xent_loss(csp_pred, csp_targets)
         return torch.stack([subs_pos_loss, csp_loss])
-
-    def build_selection_matrix_from_parent_aa(
-        self, aa_parent_idxs: torch.Tensor, mask: torch.Tensor
-    ):
-        """Build a selection matrix from a single parent amino acid sequence. Inputs are
-        expected to be as prepared in the Dataset constructor.
-        Values at ambiguous sites are meaningless.
-        """
-        with torch.no_grad():
-            per_aa_selection_factors = self.selection_factors_of_aa_idxs(
-                aa_parent_idxs.unsqueeze(0), mask.unsqueeze(0)
-            ).exp()
-            return zap_predictions_along_diagonal(
-                per_aa_selection_factors, aa_parent_idxs.unsqueeze(0), fill=1.0
-            ).squeeze(0)
-
     # This is not used anywhere, except for in a few tests. Keeping it around
     # for that reason.
     def _build_selection_matrix_from_parent(self, parent: Tuple[str, str]):
44 changes: 10 additions & 34 deletions netam/dcsm.py
@@ -2,32 +2,22 @@

 import copy
 
-import pandas as pd
 import torch
 import torch.nn.functional as F
 
 from netam.common import (
     clamp_probability,
     BIG,
 )
-from netam.dxsm import DXSMDataset, DXSMBurrito, zap_predictions_along_diagonal
+from netam.dxsm import DXSMDataset, DXSMBurrito
 import netam.molevol as molevol
 
 from netam.common import aa_idx_tensor_of_str_ambig
 from netam.sequences import (
-    aa_idx_array_of_str,
-    aa_subs_indicator_tensor_of,
     build_stop_codon_indicator_tensor,
-    nt_idx_tensor_of_str,
-    token_mask_of_aa_idxs,
-    translate_sequence,
-    translate_sequences,
     codon_idx_tensor_of_str_ambig,
-    AA_AMBIG_IDX,
     AMBIGUOUS_CODON_IDX,
     CODON_AA_INDICATOR_MATRIX,
-    RESERVED_TOKEN_REGEX,
-    MAX_AA_TOKEN_IDX,
 )

@@ -146,6 +136,8 @@ def __getitem__(self, idx):
         }
 
     def to(self, device):
+        self.aa_codon_indicator_matrix = self.aa_codon_indicator_matrix.to(device)
+        self.stop_codon_zapper = self.stop_codon_zapper.to(device)
         self.codon_parents_idxss = self.codon_parents_idxss.to(device)
         self.codon_children_idxss = self.codon_children_idxss.to(device)
         self.aa_parents_idxss = self.aa_parents_idxss.to(device)
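The two added lines let these derived tensors travel with the rest of the dataset in a single `.to(device)` call, which is what the TODO deleted further down asks for. A minimal sketch of that pattern, with made-up names rather than netam's:

```python
import torch

# Sketch only: device-dependent constant tensors live on the object and are
# moved alongside the data in one .to() call, so later matmuls never mix devices.
class DeviceAwareBatch:
    def __init__(self):
        self.indicator = torch.eye(4)  # derived, device-dependent constant
        self.data = torch.zeros(2, 4)  # actual batch contents

    def to(self, device):
        # Move every tensor attribute together.
        self.indicator = self.indicator.to(device)
        self.data = self.data.to(device)
        return self

batch = DeviceAwareBatch().to("cpu")
```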
@@ -166,7 +158,10 @@ class DCSMBurrito(DXSMBurrito):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self.xent_loss = torch.nn.CrossEntropyLoss()
-        self.stop_codon_zapper = build_stop_codon_indicator_tensor() * -BIG
+        self.stop_codon_zapper = (build_stop_codon_indicator_tensor() * -BIG).to(
+            self.device
+        )
+        self.aa_codon_indicator_matrix = CODON_AA_INDICATOR_MATRIX.to(self.device).T
 
     def prediction_pair_of_batch(self, batch):
         """Get log neutral codon substitution probabilities and log selection factors
@@ -209,13 +204,11 @@ def predictions_of_batch(self, batch):
         # of the parent codon. Namely, we need to set the parent codon to 1 -
         # sum(children).
 
-        # This indicator lifts things up from aa land to codon land.
-        # TODO I guess we could store indicator in self and have everything move with a self.to(device) call.
-        indicator = CODON_AA_INDICATOR_MATRIX.to(self.device).T
+        # The aa_codon_indicator_matrix lifts things up from aa land to codon land.
         log_preds = (
             log_neutral_codon_probs
-            + log_selection_factors @ indicator
-            + self.stop_codon_zapper.to(self.device)
+            + log_selection_factors @ self.aa_codon_indicator_matrix
+            + self.stop_codon_zapper
         )
         assert torch.isnan(log_preds).sum() == 0

@@ -258,20 +251,3 @@ def loss_of_batch(self, batch):
         codon_children_idxs = codon_children_idxs[mask]
 
         return self.xent_loss(predictions, codon_children_idxs)
-
-    # TODO copied from dasm.py (updated for new organization from Will's PR)
-    def build_selection_matrix_from_parent_aa(
-        self, aa_parent_idxs: torch.Tensor, mask: torch.Tensor
-    ):
-        """Build a selection matrix from a single parent amino acid sequence. Inputs are
-        expected to be as prepared in the Dataset constructor.
-        Values at ambiguous sites are meaningless.
-        """
-        with torch.no_grad():
-            per_aa_selection_factors = self.selection_factors_of_aa_idxs(
-                aa_parent_idxs.unsqueeze(0), mask.unsqueeze(0)
-            ).exp()
-            return zap_predictions_along_diagonal(
-                per_aa_selection_factors, aa_parent_idxs.unsqueeze(0), fill=1.0
-            ).squeeze(0)
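To make the "aa land to codon land" lift in predictions_of_batch above concrete, here is a self-contained toy version of the matrix product and the stop-codon zapper. The three-amino-acid genetic code, tensor sizes, and BIG value are invented for illustration; the real pieces are CODON_AA_INDICATOR_MATRIX from netam.sequences and BIG from netam.common.

```python
import torch

# Toy indicator: 6 codons x 3 amino acids; entry (c, a) is 1 when codon c
# encodes amino acid a. The real matrix covers 64 codons and 20 amino acids.
codon_to_aa = torch.tensor([0, 0, 1, 1, 2, 2])
indicator = torch.zeros(6, 3)
indicator[torch.arange(6), codon_to_aa] = 1.0

# Lift (batch, sites, AAs) log selection factors to (batch, sites, codons):
log_selection_factors = torch.randn(1, 5, 3)
lifted = log_selection_factors @ indicator.T
# Each codon inherits the log factor of the amino acid it encodes.
assert torch.allclose(lifted[..., 2], log_selection_factors[..., 1])

# The stop-codon zapper adds a huge negative constant to stop-codon columns,
# so their probabilities effectively vanish after exponentiation.
BIG = 1e6
stop_indicator = torch.zeros(6)
stop_indicator[5] = 1.0  # pretend codon 5 is a stop codon
log_preds = lifted + stop_indicator * -BIG
assert torch.all(log_preds[..., 5] < -1e5)
```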
19 changes: 9 additions & 10 deletions netam/dnsm.py
@@ -56,37 +56,36 @@ def update_neutral_probs(self):
             mut_probs = 1.0 - torch.exp(-branch_length * nt_rates[:parent_len])
             nt_csps = nt_csps[:parent_len, :]
 
-            # TODO singular/plural mismatch
-            neutral_aa_mut_prob = molevol.neutral_aa_mut_probs(
+            neutral_aa_mut_probs = molevol.neutral_aa_mut_probs(
                 parent_idxs.reshape(-1, 3),
                 mut_probs.reshape(-1, 3),
                 nt_csps.reshape(-1, 3, 4),
                 multihit_model=multihit_model,
             )
 
-            if not torch.isfinite(neutral_aa_mut_prob).all():
+            if not torch.isfinite(neutral_aa_mut_probs).all():
                 print(f"Found a non-finite neutral_aa_mut_prob")
                 print(f"nt_parent: {nt_parent}")
                 print(f"mask: {mask}")
                 print(f"nt_rates: {nt_rates}")
                 print(f"nt_csps: {nt_csps}")
                 print(f"branch_length: {branch_length}")
                 raise ValueError(
-                    f"neutral_aa_mut_prob is not finite: {neutral_aa_mut_prob}"
+                    f"neutral_aa_mut_prob is not finite: {neutral_aa_mut_probs}"
                 )
 
             # Ensure that all values are positive before taking the log later
-            neutral_aa_mut_prob = clamp_probability(neutral_aa_mut_prob)
+            neutral_aa_mut_probs = clamp_probability(neutral_aa_mut_probs)
 
-            pad_len = self.max_aa_seq_len - neutral_aa_mut_prob.shape[0]
+            pad_len = self.max_aa_seq_len - neutral_aa_mut_probs.shape[0]
             if pad_len > 0:
-                neutral_aa_mut_prob = F.pad(
-                    neutral_aa_mut_prob, (0, pad_len), value=1e-8
+                neutral_aa_mut_probs = F.pad(
+                    neutral_aa_mut_probs, (0, pad_len), value=1e-8
                 )
             # Here we zero out masked positions.
-            neutral_aa_mut_prob *= mask
+            neutral_aa_mut_probs *= mask
 
-            neutral_aa_mut_prob_l.append(neutral_aa_mut_prob)
+            neutral_aa_mut_prob_l.append(neutral_aa_mut_probs)
 
         # Note that our masked out positions will have a nan log probability,
         # which will require us to handle them correctly downstream.
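For readers tracking the rename: neutral_aa_mut_probs holds, per site, the probability of a neutral amino acid change. The sketch below recomputes the surrounding bookkeeping (exponential waiting-time mutation probability, clamping, padding, masking) on stand-in values; it is an illustration only, not a call into netam's clamp_probability or its real tensors.

```python
import torch
import torch.nn.functional as F

# Stand-in values, not netam data: per-site substitution rates and a branch length.
branch_length = 0.05
nt_rates = torch.tensor([1.0, 0.5, 2.0, 1.5])
# P(at least one mutation at a site) under an exponential waiting-time model,
# mirroring mut_probs = 1.0 - torch.exp(-branch_length * nt_rates) above.
mut_probs = 1.0 - torch.exp(-branch_length * nt_rates)

# Clamp away exact zeros so a later log() stays finite, pad to a fixed length
# with a tiny probability, then zero out masked (e.g. ambiguous) positions.
probs = mut_probs.clamp(min=1e-8)
max_len = 6
probs = F.pad(probs, (0, max_len - probs.shape[0]), value=1e-8)
mask = torch.tensor([1.0, 1.0, 1.0, 0.0, 0.0, 0.0])
probs = probs * mask
```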
17 changes: 17 additions & 0 deletions netam/dxsm.py
@@ -423,6 +423,23 @@ def to_crepe(self):
         encoder = framework.PlaceholderEncoder()
         return framework.Crepe(encoder, self.model, training_hyperparameters)
+
+    # This is overridden in DNSMBurrito
+    def build_selection_matrix_from_parent_aa(
+        self, aa_parent_idxs: torch.Tensor, mask: torch.Tensor
+    ):
+        """Build a selection matrix from a single parent amino acid sequence. Inputs are
+        expected to be as prepared in the Dataset constructor.
+        Values at ambiguous sites are meaningless.
+        """
+        with torch.no_grad():
+            per_aa_selection_factors = self.selection_factors_of_aa_idxs(
+                aa_parent_idxs.unsqueeze(0), mask.unsqueeze(0)
+            ).exp()
+            return zap_predictions_along_diagonal(
+                per_aa_selection_factors, aa_parent_idxs.unsqueeze(0), fill=1.0
+            ).squeeze(0)
+
     @abstractmethod
     def loss_of_batch(self, batch):
         pass
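The newly shared method ends by calling zap_predictions_along_diagonal, which pins each site's parent-residue entry to fill=1.0, so "no change" carries a neutral selection factor. A self-contained toy reimplementation under assumed shapes (not the netam function itself):

```python
import torch

def zap_along_diagonal(preds, parent_idxs, fill=1.0):
    # preds: (batch, sites, n_aas); parent_idxs: (batch, sites).
    # Set preds[b, s, parent_idxs[b, s]] = fill, i.e. the "diagonal" entry
    # for keeping the parent residue at each site.
    preds = preds.clone()
    batch_idx = torch.arange(preds.shape[0]).unsqueeze(-1)
    site_idx = torch.arange(preds.shape[1]).unsqueeze(0)
    preds[batch_idx, site_idx, parent_idxs] = fill
    return preds

per_aa_selection_factors = torch.rand(1, 4, 20)
aa_parent_idxs = torch.tensor([[3, 7, 0, 12]])
zapped = zap_along_diagonal(per_aa_selection_factors, aa_parent_idxs)
assert torch.all(zapped[0, torch.arange(4), aa_parent_idxs[0]] == 1.0)
```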
10 changes: 2 additions & 8 deletions tests/test_dcsm.py
@@ -1,14 +1,8 @@
-import os
-
 import torch
 import pytest
 
-from netam.common import BIG, force_spawn
-from netam.framework import (
-    crepe_exists,
-    load_crepe,
-)
-from netam.sequences import MAX_AA_TOKEN_IDX, MAX_KNOWN_TOKEN_COUNT
+from netam.common import force_spawn
+from netam.sequences import MAX_KNOWN_TOKEN_COUNT
 from netam.models import TransformerBinarySelectionModelWiggleAct
 from netam.dcsm import (
     DCSMBurrito,
