Replacing normalize_sub_probs with a check; fixing consistency problems #77

Merged: 5 commits, Oct 30, 2024
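In brief, call sites that previously renormalized substitution probabilities in place now expect pre-normalized CSPs and only validate them. A minimal before/after sketch of the call-site pattern, using the names that appear in the dasm.py hunk below (dnsm.py makes the analogous change):

# before: renormalize the substitution probabilities at the call site
normed_nt_csps = molevol.normalize_sub_probs(parent_idxs, nt_csps[:parent_len, :])

# after: slice, then assert that the CSPs are already a valid distribution
nt_csps = nt_csps[:parent_len, :]
molevol.check_csps(parent_idxs, nt_csps)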
35 changes: 17 additions & 18 deletions netam/dasm.py
@@ -25,12 +25,12 @@ class DASMDataset(dnsm.DNSMDataset):
def update_neutral_probs(self):
neutral_aa_probs_l = []

for nt_parent, mask, nt_rates, branch_length, nt_csps in zip(
for nt_parent, mask, nt_rates, nt_csps, branch_length in zip(
self.nt_parents,
self.mask,
self.masks,
self.nt_ratess,
self._branch_lengths,
self.nt_cspss,
self._branch_lengths,
):
mask = mask.to("cpu")
nt_rates = nt_rates.to("cpu")
@@ -41,14 +41,13 @@ def update_neutral_probs(self):
parent_len = len(nt_parent)

mut_probs = 1.0 - torch.exp(-branch_length * nt_rates[:parent_len])
normed_nt_csps = molevol.normalize_sub_probs(
parent_idxs, nt_csps[:parent_len, :]
)
nt_csps = nt_csps[:parent_len, :]
molevol.check_csps(parent_idxs, nt_csps)

neutral_aa_probs = molevol.neutral_aa_probs(
parent_idxs.reshape(-1, 3),
mut_probs.reshape(-1, 3),
normed_nt_csps.reshape(-1, 3, 4),
nt_csps.reshape(-1, 3, 4),
)

if not torch.isfinite(neutral_aa_probs).all():
@@ -75,25 +74,25 @@ def update_neutral_probs(self):

# Note that our masked out positions will have a nan log probability,
# which will require us to handle them correctly downstream.
self.log_neutral_aa_probs = torch.log(torch.stack(neutral_aa_probs_l))
self.log_neutral_aa_probss = torch.log(torch.stack(neutral_aa_probs_l))

def __getitem__(self, idx):
return {
"aa_parents_idxs": self.aa_parents_idxs[idx],
"aa_children_idxs": self.aa_children_idxs[idx],
"subs_indicator": self.aa_subs_indicator_tensor[idx],
"mask": self.mask[idx],
"log_neutral_aa_probs": self.log_neutral_aa_probs[idx],
"aa_parents_idxs": self.aa_parents_idxss[idx],
"aa_children_idxs": self.aa_children_idxss[idx],
"subs_indicator": self.aa_subs_indicators[idx],
"mask": self.masks[idx],
"log_neutral_aa_probs": self.log_neutral_aa_probss[idx],
"nt_rates": self.nt_ratess[idx],
"nt_csps": self.nt_cspss[idx],
}

def to(self, device):
self.aa_parents_idxs = self.aa_parents_idxs.to(device)
self.aa_children_idxs = self.aa_children_idxs.to(device)
self.aa_subs_indicator_tensor = self.aa_subs_indicator_tensor.to(device)
self.mask = self.mask.to(device)
self.log_neutral_aa_probs = self.log_neutral_aa_probs.to(device)
self.aa_parents_idxss = self.aa_parents_idxss.to(device)
self.aa_children_idxss = self.aa_children_idxss.to(device)
self.aa_subs_indicators = self.aa_subs_indicators.to(device)
self.masks = self.masks.to(device)
self.log_neutral_aa_probss = self.log_neutral_aa_probss.to(device)
self.nt_ratess = self.nt_ratess.to(device)
self.nt_cspss = self.nt_cspss.to(device)
if self.multihit_model is not None:
59 changes: 31 additions & 28 deletions netam/dnsm.py
@@ -70,25 +70,29 @@ def __init__(
# We have sequences of varying length, so we start with all tensors set
# to the ambiguous amino acid, and then will fill in the actual values
# below.
self.aa_parents_idxs = torch.full(
self.aa_parents_idxss = torch.full(
(pcp_count, self.max_aa_seq_len), MAX_AMBIG_AA_IDX
)
self.aa_children_idxs = self.aa_parents_idxs.clone()
self.aa_subs_indicator_tensor = torch.zeros((pcp_count, self.max_aa_seq_len))
self.aa_children_idxss = self.aa_parents_idxss.clone()
self.aa_subs_indicators = torch.zeros((pcp_count, self.max_aa_seq_len))

self.mask = torch.ones((pcp_count, self.max_aa_seq_len), dtype=torch.bool)
self.masks = torch.ones((pcp_count, self.max_aa_seq_len), dtype=torch.bool)

for i, (aa_parent, aa_child) in enumerate(zip(aa_parents, aa_children)):
self.mask[i, :] = aa_mask_tensor_of(aa_parent, self.max_aa_seq_len)
self.masks[i, :] = aa_mask_tensor_of(aa_parent, self.max_aa_seq_len)
aa_seq_len = len(aa_parent)
self.aa_parents_idxs[i, :aa_seq_len] = aa_idx_tensor_of_str_ambig(aa_parent)
self.aa_children_idxs[i, :aa_seq_len] = aa_idx_tensor_of_str_ambig(aa_child)
self.aa_subs_indicator_tensor[i, :aa_seq_len] = aa_subs_indicator_tensor_of(
self.aa_parents_idxss[i, :aa_seq_len] = aa_idx_tensor_of_str_ambig(
aa_parent
)
self.aa_children_idxss[i, :aa_seq_len] = aa_idx_tensor_of_str_ambig(
aa_child
)
self.aa_subs_indicators[i, :aa_seq_len] = aa_subs_indicator_tensor_of(
aa_parent, aa_child
)

assert torch.all(self.mask.sum(dim=1) > 0)
assert torch.max(self.aa_parents_idxs) <= MAX_AMBIG_AA_IDX
assert torch.all(self.masks.sum(dim=1) > 0)
assert torch.max(self.aa_parents_idxss) <= MAX_AMBIG_AA_IDX

self._branch_lengths = branch_lengths
self.update_neutral_probs()
@@ -242,12 +246,12 @@ def update_neutral_probs(self):
"""
neutral_aa_mut_prob_l = []

for nt_parent, mask, nt_rates, branch_length, nt_csps in zip(
for nt_parent, mask, nt_rates, nt_csps, branch_length in zip(
self.nt_parents,
self.mask,
self.masks,
self.nt_ratess,
self._branch_lengths,
self.nt_cspss,
self._branch_lengths,
):
mask = mask.to("cpu")
nt_rates = nt_rates.to("cpu")
@@ -260,16 +264,15 @@ def update_neutral_probs(self):
# with masking out these positions later. We do this below.
parent_idxs = sequences.nt_idx_tensor_of_str(nt_parent.replace("N", "A"))
parent_len = len(nt_parent)
molevol.check_csps(parent_idxs, nt_csps)

mut_probs = 1.0 - torch.exp(-branch_length * nt_rates[:parent_len])
normed_nt_csps = molevol.normalize_sub_probs(
parent_idxs, nt_csps[:parent_len, :]
)
nt_csps = nt_csps[:parent_len, :]

neutral_aa_mut_prob = molevol.neutral_aa_mut_probs(
parent_idxs.reshape(-1, 3),
mut_probs.reshape(-1, 3),
normed_nt_csps.reshape(-1, 3, 4),
nt_csps.reshape(-1, 3, 4),
multihit_model=multihit_model,
)

@@ -299,26 +302,26 @@ def update_neutral_probs(self):

# Note that our masked out positions will have a nan log probability,
# which will require us to handle them correctly downstream.
self.log_neutral_aa_mut_probs = torch.log(torch.stack(neutral_aa_mut_prob_l))
self.log_neutral_aa_mut_probss = torch.log(torch.stack(neutral_aa_mut_prob_l))

def __len__(self):
return len(self.aa_parents_idxs)
return len(self.aa_parents_idxss)

def __getitem__(self, idx):
return {
"aa_parents_idxs": self.aa_parents_idxs[idx],
"subs_indicator": self.aa_subs_indicator_tensor[idx],
"mask": self.mask[idx],
"log_neutral_aa_mut_probs": self.log_neutral_aa_mut_probs[idx],
"aa_parents_idxs": self.aa_parents_idxss[idx],
"aa_subs_indicator": self.aa_subs_indicators[idx],
"mask": self.masks[idx],
"log_neutral_aa_mut_probs": self.log_neutral_aa_mut_probss[idx],
"nt_rates": self.nt_ratess[idx],
"nt_csps": self.nt_cspss[idx],
}

def to(self, device):
self.aa_parents_idxs = self.aa_parents_idxs.to(device)
self.aa_subs_indicator_tensor = self.aa_subs_indicator_tensor.to(device)
self.mask = self.mask.to(device)
self.log_neutral_aa_mut_probs = self.log_neutral_aa_mut_probs.to(device)
self.aa_parents_idxss = self.aa_parents_idxss.to(device)
self.aa_subs_indicators = self.aa_subs_indicators.to(device)
self.masks = self.masks.to(device)
self.log_neutral_aa_mut_probss = self.log_neutral_aa_mut_probss.to(device)
self.nt_ratess = self.nt_ratess.to(device)
self.nt_cspss = self.nt_cspss.to(device)
if self.multihit_model is not None:
@@ -369,7 +372,7 @@ def predictions_of_batch(self, batch):
return self.predictions_of_pair(log_neutral_aa_mut_probs, log_selection_factors)

def loss_of_batch(self, batch):
aa_subs_indicator = batch["subs_indicator"].to(self.device)
aa_subs_indicator = batch["aa_subs_indicator"].to(self.device)
mask = batch["mask"].to(self.device)
aa_subs_indicator = aa_subs_indicator.masked_select(mask)
predictions = self.predictions_of_batch(batch).masked_select(mask)
43 changes: 15 additions & 28 deletions netam/molevol.py
@@ -1,12 +1,7 @@
"""Free functions for molecular evolution computation.

We will follow terminology from Yaari et al 2013, where "mutability" refers to the
probability of a nucleotide mutating at a given site, while "substitution" refers to the
probability of a nucleotide mutating to another nucleotide at a given site conditional
on having a mutation.

We assume that the mutation and substitution probabilities already take branch length
into account.
CSP means conditional substitution probability. CSPs are the probabilities of alternate
states conditioned on there being a substitution.
"""

import numpy as np
@@ -19,41 +14,33 @@
import netam.sequences as sequences


def normalize_sub_probs(parent_idxs: Tensor, sub_probs: Tensor) -> Tensor:
"""Normalize substitution probabilities.

Given a parent DNA sequence and a 2D PyTorch tensor representing substitution
probabilities, this function sets the probability of the actual nucleotide
in the parent sequence to zero and then normalizes each row to form a valid
probability distribution.
def check_csps(parent_idxs: Tensor, csps: Tensor) -> Tensor:
"""Make sure that the CSPs are valid, i.e. that they are a probability distribution
and the parent state is zero.

Args:
parent_idxs (torch.Tensor): The parent sequence indices.
sub_probs (torch.Tensor): A 2D PyTorch tensor representing substitution
probabilities. Rows correspond to sites, and columns correspond
to "ACGT" bases.

Returns:
torch.Tensor: A 2D PyTorch tensor with normalized substitution probabilities.
to states (e.g. nucleotides).
"""

# Assert that sub_probs are within the range [0, 1] modulo rounding error
assert torch.all(csps >= -1e-6), "Substitution probabilities must be non-negative"
assert torch.all(
sub_probs >= -1e-6
), "Substitution probabilities must be non-negative"
assert torch.all(
sub_probs <= 1 + 1e-6
csps <= 1 + 1e-6
), "Substitution probabilities must be less than or equal to 1"

# Create an array of row indices that matches the shape of `parent_idxs`.
row_indices = torch.arange(len(parent_idxs))

# Set the entries corresponding to the parent sequence to zero.
sub_probs[row_indices, parent_idxs] = 0.0

# Normalize the probabilities.
row_sums = torch.sum(sub_probs, dim=1, keepdim=True)
return sub_probs / row_sums
# Assert that the parent entry has a substitution probability of nearly 0.
assert torch.all(
csps[row_indices, parent_idxs] < 1e-6
), "Parent entry must have a substitution probability of nearly 0"
assert torch.allclose(
csps[: len(parent_idxs)].sum(dim=1), torch.ones(len(parent_idxs))
)


def build_mutation_matrices(
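With normalization removed from the call sites above, whatever produces nt_csps upstream must supply rows that already sum to one and are zero at the parent state, or check_csps will raise. A minimal sketch of such a conversion (csps_of_sub_probs is a hypothetical helper for illustration, not part of this PR; it mirrors what normalize_sub_probs used to do):

import torch

def csps_of_sub_probs(parent_idxs: torch.Tensor, sub_probs: torch.Tensor) -> torch.Tensor:
    # Zero out the parent state in each row, then renormalize each row so it
    # sums to one; the result satisfies the assertions in check_csps.
    csps = sub_probs.clone()
    csps[torch.arange(len(parent_idxs)), parent_idxs] = 0.0
    return csps / csps.sum(dim=1, keepdim=True)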
19 changes: 8 additions & 11 deletions tests/test_molevol.py
@@ -1,4 +1,6 @@
import torch
import pytest

import netam.molevol as molevol
from netam import framework

Expand Down Expand Up @@ -99,19 +101,14 @@ def test_neutral_aa_mut_probs():
assert torch.allclose(correct_tensor, computed_tensor)


def test_normalize_sub_probs():
def test_check_csps():
parent_idxs = nt_idx_tensor_of_str("AC")
sub_probs = torch.tensor([[0.2, 0.3, 0.4, 0.1], [0.1, 0.2, 0.3, 0.4]])

expected_normalized = torch.tensor(
[[0.0, 0.375, 0.5, 0.125], [0.125, 0.0, 0.375, 0.5]]
)
normalized_sub_probs = molevol.normalize_sub_probs(parent_idxs, sub_probs)
csp = torch.tensor([[0.0, 0.375, 0.5, 0.125], [0.125, 0.0, 0.375, 0.5]])
molevol.check_csps(parent_idxs, csp)

assert normalized_sub_probs.shape == (2, 4), "Result has incorrect shape"
assert torch.allclose(
normalized_sub_probs, expected_normalized
), "Unexpected normalized values"
not_csp = torch.tensor([[0.2, 0.3, 0.4, 0.1], [0.1, 0.2, 0.3, 0.4]])
with pytest.raises(AssertionError):
molevol.check_csps(parent_idxs, not_csp)


def iterative_aaprob_of_mut_and_sub(parent_codon, mut_probs, csps):